forked from jiuyuan/InfiniTensor

Add: nnet code

This commit is contained in:
parent 1205240218
commit b7e2096a26
@@ -4,10 +4,10 @@ include(CMakeDependentOption)
 project(InfiniTensor C CXX)

 # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
-option(BUILD_TEST "Build tests" OFF)
+option(BUILD_TEST "Build tests" ON)
+cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
-cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF)
-cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF)
+cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
+cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)

 set(DEFAULT_BUILD_TYPE "RelWithDebInfo")

@@ -31,16 +31,24 @@ if(OpenMP_CXX_FOUND)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
 endif()

-# add_subdirectory(3rd-party/pybind11)
-# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)

 include_directories(include)

+# # Pybind11
+# add_subdirectory(3rd-party/pybind11)
+# include_directories(3rd-party/pybind11/include)
+# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
+
+# nlohmann_json
+add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
+include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)

 if(BUILD_TEST)
-  set(BUILD_GMOCK off)
-  set(INSTALL_GTEST off)
+  # TODO: check set
+  set(BUILD_GMOCK
+      OFF
+      CACHE BOOL "Do not build gmock" FORCE)
+  set(INSTALL_GTEST
+      OFF
+      CACHE BOOL "Do not install gtest" FORCE)
   add_subdirectory(3rd-party/googletest)
   include_directories(3rd-party/googletest/googletest/include)
 endif()

@@ -59,13 +67,26 @@ add_library(InfiniTensor SHARED ${SRC})
 # pybind11_add_module(infini MODULE ${FFI})
 # target_link_libraries(infini PRIVATE infini_cpp)

-if(BUILD_TEST)
-  enable_testing()
-  file(GLOB_RECURSE TEST_SOURCES test/*.cc)
+function(build_test files)
+  # Non-recursive glob to skip failed tests
+  file(GLOB TEST_SOURCES ${files})
   foreach(testsourcefile ${TEST_SOURCES})
     get_filename_component(testname ${testsourcefile} NAME_WE)
     add_executable(${testname} ${testsourcefile})
-    target_link_libraries(${testname} InfiniTensor gtest_main gtest)
+    target_link_libraries(${testname} InfiniTensor GTest::gtest_main)
     add_test(NAME ${testname} COMMAND ${testname})
   endforeach(testsourcefile ${TEST_SOURCES})
-endif()
+endfunction()
+
+if(BUILD_TEST)
+  enable_testing()
+  if(BUILD_TEST_CORE)
+    build_test(test/core/*.cc)
+  endif()
+  if(BUILD_TEST_PET)
+    build_test(test/pet/*.cc)
+  endif()
+  if(BUILD_TEST_EINNET)
+    build_test(test/nnet/*.cc)
+  endif()
+endif()

@@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class MatchComputationKernel : public Pass {
  public:
    MatchComputationKernel(Derivator &derivator)
        : Pass(derivator, "MatchComputationKernel") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};

} // namespace nnet

@@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class MatchMemBoundKernel : public Pass {
  public:
    MatchMemBoundKernel(Derivator &derivator)
        : Pass(derivator, "MatchMemBoundKernel") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};

} // namespace nnet

@@ -0,0 +1,41 @@
#pragma once
#include "nnet/derivator.h"

namespace nnet {

class Pass {
  private:
    VecExpr transformations;

  protected:
    Derivator &derivator;
    string passName;
    /**
     * @brief Whether to record logs in the Derivator. It should be false for
     * a single-Pass test to avoid a mismatch between passInfos and passMsgs
     * caused by different numbers of "run" and "nextStep" calls.
     */
    bool enableLogging, enableDebug;

    virtual void transform(Formula &origin, int depth, Expr &rCur) = 0;
    void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
                  const string &ruleInfo = "");

    Var getNewVar();
    string newTensorName();

  private:
    void initialize(Formula &origin, const Expr &rCur);
    void finalize();

  public:
    Pass(Derivator &derivator, const string &passName);
    virtual ~Pass();

    void run(Formula &origin, int dfsDepth, Expr &rCur);
    void setEnableLogging(bool value);
    void setEnableDebug(bool value);
    const VecExpr &getTransformations();
};

} // namespace nnet
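
The Pass interface above fixes the protocol every derivation rule follows: run() wraps the pure-virtual transform() with initialize()/finalize(), and a rule reports each rewrite through the protected nextStep() hook. Below is a minimal sketch (not part of this commit) of a conforming rule, assuming only the declarations in this header; the rewrite itself is a placeholder.

#include "nnet/Pass/Pass.h"

namespace nnet {

// Hypothetical example pass, not part of this commit.
class NoOpPass : public Pass {
  public:
    NoOpPass(Derivator &derivator) : Pass(derivator, "NoOpPass") {}

  private:
    void transform(Formula &origin, int depth, Expr &rCur) override {
        // A real rule would derive a rewritten expression from rCur here;
        // re-submitting the current expression keeps the sketch self-contained.
        nextStep(origin, depth, rCur, rCur, "no-op rewrite");
    }
};

} // namespace nnet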

@@ -0,0 +1,18 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"

namespace nnet {

class Rule1VariableSplit : public Pass {
  public:
    Rule1VariableSplit(Derivator &derivator)
        : Pass(derivator, "Rule1VariableSplit") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
    vector<Replace> getSplitableVar(const RangeOp &rangeOp);
    Expr replaceIters(Expr cur, const Replace &replace);
};

} // namespace nnet

@@ -0,0 +1,29 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"

namespace nnet {

class Rule2VariableMerging : public Pass {
  private:
    map<int, vector<Var>> substituteRules;

  public:
    Rule2VariableMerging(Derivator &derivator)
        : Pass(derivator, "Rule2VariableMerging") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;

    vector<Replace> getMergableReplaces(RangeOp rangeOp, int depth);
    optional<Replace> getReplaceMergingTwoLoopIters(const RangeOp &rangeOp,
                                                    pair<Iterator, int> pairA,
                                                    pair<Iterator, int> pairB,
                                                    const IteratorTable &exprIT,
                                                    int tensorID);
    optional<Replace> getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
                                                    pair<Iterator, int> pa,
                                                    pair<Iterator, int> pb);
};

} // namespace nnet

@@ -0,0 +1,19 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule3StageSplit : public Pass {
  private:
    map<int, vector<Var>> substituteRules;

  public:
    Rule3StageSplit(Derivator &derivator)
        : Pass(derivator, "Rule3StageSplit") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
    vector<vector<Var>> getSplitSummationIters(RangeOp rangeOp);
};

} // namespace nnet

@@ -0,0 +1,22 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule4StageMerging : public Pass {
    bool success, mergeStageWithCalc;

  public:
    Rule4StageMerging(Derivator &derivator)
        : Pass(derivator, "Rule4StageMerging"), success(false),
          mergeStageWithCalc(false) {}
    bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
                           bool mergeStageWithCalc = false);
    bool isSuccessful();
    void setMergeStageWithCalc(bool value);

  private:
    virtual void transform(Formula &origin, int depth, Expr &rCur) override;
};

} // namespace nnet

@@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule5RangeRelaxation : public Pass {
  public:
    Rule5RangeRelaxation(Derivator &derivator)
        : Pass(derivator, "Rule5RangeRelaxation") {}
    Expr rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);

  private:
    virtual void transform(Formula &origin, int depth, Expr &rCur) override;
};

} // namespace nnet

@@ -0,0 +1,17 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule6KenerlMatching : public Pass {
  public:
    Rule6KenerlMatching(Derivator &derivator)
        : Pass(derivator, "Rule6KenerlMatching") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
    // RE: separating this func out is a design choice.
    VecExpr matchElementWise(const RangeOp &rangeOp);
};

} // namespace nnet

@@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule7DLT : public Pass {
  public:
    Rule7DLT(Derivator &derivator) : Pass(derivator, "Rule7DLT") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
    Expr buildDLTSingleRangeOp(const RangeOp &original, const Expr &newSummand);
    vector<int> getFactors();
};

} // namespace nnet

@@ -0,0 +1,48 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"

namespace nnet {

class Rule8GuidedDLT : public Pass {
  public:
    Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {}
    VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur,
                      bool debug = false);

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
    /**
     * @brief If only one row mismatches (i.e., more iterators mismatch),
     * directly construct the data layout according to the iterator table.
     *
     * @return Expr Returns nullptr if failed.
     */
    Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch,
                           const IteratorTable &exprIT, const Pattern &pattern);
    /**
     * @brief Check whether two iterators overlap each other. If they overlap,
     * we cannot simply reconstruct the tensor into a new one by separating
     * all iterators into different dimensions.
     */
    bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT,
                                                   int tensorID);
    /**
     * @brief Only a product of two tensors can be guided-DLT'ed.
     *
     * @param cur
     * @return true
     * @return false
     */
    bool statisfyGuidedDLT(RangeOp cur) const;
    /**
     * @brief Deal with output DLT mismatch only.
     */
    Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch,
                             const IteratorTable &exprIT,
                             const Pattern &pattern);
    Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace,
                              vector<Var> tensorDimAxes, vector<int> newShape);
};

} // namespace nnet

@@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule90TwoStageElementWise : public Pass {
  public:
    Rule90TwoStageElementWise(Derivator &derivator)
        : Pass(derivator, "Rule90TwoStageElementWise") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
    VecExpr matchTwoStageElementWise(const RangeOp &rangeOp);
};

} // namespace nnet

@@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule91MergeStagesWithSum : public Pass {
  public:
    Rule91MergeStagesWithSum(Derivator &derivator)
        : Pass(derivator, "Rule91MergeStagesWithSum") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};

} // namespace nnet

@@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"

namespace nnet {

class Rule9RangeMagnify : public Pass {
  public:
    Rule9RangeMagnify(Derivator &derivator)
        : Pass(derivator, "Rule9RangeMagnify") {}

  private:
    virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};

} // namespace nnet

@@ -0,0 +1,46 @@
#pragma once
#include "nnet/expr.h"

namespace nnet {

struct Replace {
    int iteratorType;
    vector<Var> oldIters; // i_1, ...
    vector<Var> newIters; // j_1, ...
    VecExpr phis;         // j_1 = \phi_1(i_1, ...), not necessary for Sum iter
    VecExpr psis;         // i_1 = \psi_1(j_1, ...)
    vector<VarRangePair> newVarRanges;

    bool isReplaced(Var var) const {
        for (const auto &iter : oldIters)
            if (iter->equal(var))
                return true;
        return false;
    }

    string toReadable() const {
        string ret = "Old iters: " + serializeVec(oldIters) +
                     ", new iters: " + serializeVec(newIters);
        ret += " phis: " + serializeVec(phis) + " psis: " + serializeVec(psis);
        return ret;
    }
};

class ReplaceKit {
  public:
    static RangeOp replaceRangeOpIterator(const RangeOp &rangeOp,
                                          const Replace &replace,
                                          const Expr &replacedSummand);
    static Subscript buildSubscirptForLoopVarReplace(const RangeOp &inner,
                                                     const Replace &replace);
    static RangeOp buildDLTOuterRangeOp(const RangeOp &original,
                                        const Subscript &subscriptedNewRangeOp);
    static Expr replaceMultipleExprs(const Expr &cur,
                                     const vector<Var> &patterns,
                                     const VecExpr &replacements,
                                     bool simplify = true);
    static Expr replaceExpr(const Expr &cur, const Expr &pattern,
                            const Expr &replacement);
};

} // namespace nnet
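
As a sketch (not from the commit), a Replace describing a single-iterator substitution j = phi(i) with inverse i = psi(j) can be filled in field by field. The Var/Expr arguments are taken as parameters because expression construction lives in nnet/expr.h, outside this hunk, and the iteratorType value used is an assumption.

#include "nnet/ReplaceKit.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
string describeRename(const Var &i, const Var &j, const Expr &phi,
                      const Expr &psi) {
    Replace replace;
    replace.iteratorType = 0; // assumed iterator-type value; not shown here
    replace.oldIters = {i};
    replace.newIters = {j};
    replace.phis = {phi}; // j = phi(i)
    replace.psis = {psi}; // i = psi(j)
    nnet_assert(replace.isReplaced(i), "i must be listed in oldIters");
    return replace.toReadable();
}

} // namespace nnet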

@@ -0,0 +1,38 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"

namespace nnet {

class AsTVMVisitor : public Functor<std::string(void)> {
  private:
    int nStage = 0, curStage = -1;
    std::unordered_map<std::string, int> offset;
    std::vector<std::string> inputs;
    std::string output;
    std::vector<std::string> pythonVars;
    std::vector<std::vector<int>> inputShapes;
    std::vector<int> outputShape;
    std::string stmts;

  public:
    std::string getStmts() const;

    const std::vector<std::string> &getInputs() const { return inputs; }
    const std::string &getOutput() const { return output; }

    const std::vector<std::vector<int>> &getInputShapes() const {
        return inputShapes;
    }
    const std::vector<int> &getOutputShape() const { return outputShape; }

    std::string visit_(const Constant &c) override;
    std::string visit_(const BinaryOp &c) override;
    std::string visit_(const Func &c) override;
    std::string visit_(const RangeOp &c) override;
    std::string visit_(const Subscript &c) override;
    std::string visit_(const Var &c) override;
    std::string visit_(const Tensor &c) override;
};

} // namespace nnet

@@ -0,0 +1,22 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class CheckOOBVisitor : public ExprTreeVisitor {
    RangeOp rangeOp;
    bool detect = false;

  public:
    CheckOOBVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 0, 0, _verobse) {}
    void visit_(const Subscript &c) override;

    /**
     * @brief
     * @return true If there is OOB
     * @return false If there is no OOB
     */
    bool checkRangeOp(const RangeOp &_rangeOp);
};

} // namespace nnet
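
Usage is a one-shot query; since the visitor keeps detection state, a sketch (not from the commit) would use a fresh instance per check. The include path is assumed from the include convention visible above.

#include "nnet/Visitor/CheckOOBVisitor.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
bool hasOutOfBoundsAccess(const RangeOp &rangeOp) {
    return CheckOOBVisitor().checkRangeOp(rangeOp);
}

} // namespace nnet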

@@ -0,0 +1,16 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

// Clone ExprNodes in a stage except Tensor, Var, and Constant nodes.
class CloneMutator : public Mutator {
  public:
    CloneMutator() : Mutator(false) {}
    Expr visit_(const Constant &c) override;
    Expr visit_(const Var &c) override;
    Expr visit_(const Tensor &c) override;
    Expr clone(const Expr &c) { return dispatch(c); }
};

} // namespace nnet

@@ -0,0 +1,15 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class CompareMultiFormulasVisitor : public ExprTreeVisitor {
    vector<VarRangePair> newSumVarRanges;
    RangeOp newRangeOp;

  public:
    CompareMultiFormulasVisitor() : ExprTreeVisitor() {}
    bool compare(const VecExpr &roots);
};

} // namespace nnet

@@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class CountRoutineVisitor : public ExprTreeVisitor {
  private:
    vector<int> cnts;

  public:
    CountRoutineVisitor(int _verobse = 0)
        : ExprTreeVisitor(1, 1, 1, 1, _verobse) {}
    void visit_(const Tensor &c) override;
    vector<int> count(const Expr &root);
    bool match(const Expr &root, int nMatmul = 0, int nConv = 0,
               int nElement = 0, int nSg2bmm = 0, int nLongformerGBMM = 0);
};
} // namespace nnet

@@ -0,0 +1,25 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class FullPrinterVisitor : public ExprTreeVisitor {
  private:
    vector<tuple<string, Routine, Tensor>> q;

  public:
    FullPrinterVisitor(int _verobse = 0)
        : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
    void visit_(const Tensor &c) override;

    string print(const Expr &root);
    /**
     * @brief Get all tensors & OPs in a reversed order
     *
     * @param root
     * @return vector<<Output TensorName, RoutineNode, output tensor in NNet>>
     */
    const vector<tuple<string, Routine, Tensor>> &traverse(const Expr &root);
};

} // namespace nnet
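
A sketch (not from the commit) of consuming the reversed traversal; the visitor must outlive the returned reference, since traverse() hands back its internal vector. The include path is assumed from the convention above.

#include "nnet/Visitor/FullPrinterVisitor.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
void dumpStages(const Expr &root) {
    FullPrinterVisitor printer;
    for (const auto &[name, routine, tensor] : printer.traverse(root))
        dbg(name); // routine and output tensor omitted in this sketch
}

} // namespace nnet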

@@ -0,0 +1,22 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

// Get all tensors in the stage
class GetTensorsVisitor : public ExprTreeVisitor {
  private:
    unordered_map<string, Tensor> tensors;

    void visit_(const Tensor &c) override;

  public:
    GetTensorsVisitor(int _verobse = 0)
        : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
    auto get(const Expr &c) {
        dispatch(c);
        return tensors;
    }
};

} // namespace nnet

@@ -0,0 +1,31 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

// Calculate hash for a normal form, starting at a RangeOp
class HashVisitor : public Functor<HashType(void)> {
    inline const static HashType BKDR_SEED[] = {131, 313, 10007, 65599};

    PtrUmap<Iterator, int> varHash;
    int nLoopVars = 0;
    PtrUmap<Iterator, int> name2id;
    vector<int> rootId;
    vector<bool> haveAlias;
    int nVars = 0;
    vector<HashType> power;

  private:
    HashType visit_(const Constant &c) override;
    HashType visit_(const BinaryOp &c) override;
    HashType visit_(const RangeOp &c) override;
    HashType visit_(const Subscript &c) override;
    HashType visit_(const Tensor &c) override;
    HashType visit_(const Var &c) override;

  public:
    HashVisitor(int _verobse = 0) : Functor(_verobse) {}
    HashType getHash(const Expr &c);
};

} // namespace nnet
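
Since equal normal forms hash equally, a sketch (not from the commit) of a cheap pre-filter before full comparison; HashVisitor keeps per-run state, so one instance is used per expression. The include path is assumed from the convention above.

#include "nnet/Visitor/HashVisitor.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
bool maybeEquivalent(const Expr &a, const Expr &b) {
    return HashVisitor().getHash(a) == HashVisitor().getHash(b);
}

} // namespace nnet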

@@ -0,0 +1,23 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class InputVisitor : public ExprTreeVisitor {
    vector<Tensor> inputs;

  public:
    int nInputs = 0;
    InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
    void visit_(const Tensor &c) override;

    /**
     * @brief Get all the inputs in the nested stages
     */
    vector<Tensor> getInputs(const RangeOp &_rangeOp) {
        dispatch(_rangeOp);
        return inputs;
    }
};

} // namespace nnet

@@ -0,0 +1,55 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class Interpreter : public Functor<int()> {
  public:
    using ttype = int; // Test data type
    using rtype = int; // Return data type
    using Position = vector<int>;
    using Inputs = unordered_map<string, Ref<vector<ttype>>>;
    using Iteration = PtrUmap<Var, int>;

  private:
    // cache the input value
    Inputs inputs;
    vector<Iteration> iterations;
    vector<Position> positions;

    rtype visit_(const Constant &c) override;
    rtype visit_(const BinaryOp &c) override;
    rtype visit_(const RangeOp &c) override;
    rtype visit_(const Subscript &c) override;
    rtype visit_(const Var &c) override;
    rtype visit_(const Tensor &c) override;
    // int visit_(const Func &c); // Future work

    static Inputs genInputStartingFromZero(const RangeOp &range);

  public:
    Interpreter(Inputs _inputs, int _verbose = 0)
        : Functor(_verbose), inputs(_inputs) {}
    Interpreter(RangeOp range, int _verbose = 0);

    /**
     * @brief Calculate the output at specified positions
     *
     * @param expr The expression to be calculated.
     * @param poses Positions of output.
     * @return vector<int> Value of output.
     */
    vector<rtype> interpret(const Expr &expr, const vector<Position> &poses);
    /**
     * @brief Calculate the output at equally spaced positions
     *
     * @param expr The expression to be calculated.
     * @param nPoses The number of calculated output positions.
     * @return vector<int> Value of output.
     */
    vector<rtype> interpretUniformSample(const RangeOp &range,
                                         int nPoses = 100);
    vector<rtype> interpretAllOutput(const RangeOp &range);
};

} // namespace nnet
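
A sketch (not from the commit) of the spot-check this interpreter enables: two derived formulas that should be equivalent are sampled at the same positions and compared, assuming both read the same named input tensors (the RangeOp constructor fills inputs via genInputStartingFromZero). The include path is assumed from the convention above.

#include "nnet/Visitor/Interpreter.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
bool agreeOnSamples(const RangeOp &a, const RangeOp &b, int nPoses = 16) {
    return Interpreter(a).interpretUniformSample(a, nPoses) ==
           Interpreter(b).interpretUniformSample(b, nPoses);
}

} // namespace nnet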

@@ -0,0 +1,14 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class MatchReshapeVisitor : public Functor<bool(void)> {
  private:
    PtrMap<Iterator, int> _coefficient;

  public:
    bool visit_(const RangeOp &c) override;
};

} // namespace nnet

@@ -0,0 +1,60 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"

namespace nnet {

class MatchTableVisitor
    : public Functor<void(const Tensor &, int dim, optional<int> stride)> {
  private:
    // Var -> {(tensor, dim)}
    Appearance appearance;
    vector<Tensor> tensors;
    vector<Subscript> subscripts;
    StrideTable strideTable;
    PtrMap<Iterator, vector<vector<int>>>
        strideInDim; // [Iterator][tensorID][dim]=stride

    // Intermediate variable
    // Stride of a sub-exprtree product: it has to be computed in two DFS passes
    SubexprSride subexprStride;
    bool hasUnsupportedOp = false;

  public:
    MatchTableVisitor(int _verobse = 0) : Functor(_verobse) {}
    void visit_(const BinaryOp &c, const Tensor &tensor, int dim,
                optional<int> stride) override;
    void visit_(const Subscript &c, const Tensor &tensor, int dim,
                optional<int> stride) override;
    void visit_(const Var &c, const Tensor &tensor, int dim,
                optional<int> stride) override;
    void visit_(const Constant &c, const Tensor &tensor, int dim,
                optional<int> stride) override;
    // void visit_(const Tensor &c, const Tensor &tensor) override;

    [[nodiscard]] bool operator()(const RangeOp &e) {
        hasUnsupportedOp = false;
        // get the location and stride of each iterator
        auto mulOp = as<BinaryOpNode>(e->getSummand());
        // TODO [feature]: support complex index exprs
        if (!mulOp || mulOp->getOpType() != OpType::Mul) {
            nnet_unimplemented_continue();
            return false;
        }
        StrideVisitor strideVisitor(0);
        subexprStride = strideVisitor.getFormulaStride(e);
        dispatch(mulOp->getLhs(), nullptr, 0, 0);
        dispatch(mulOp->getRhs(), nullptr, 0, 0);
        subscripts.emplace_back(as<SubscriptNode>(mulOp->getLhs()));
        subscripts.emplace_back(as<SubscriptNode>(mulOp->getRhs()));
        assert(tensors.size() == subscripts.size());
        assert(tensors.size() < 5);
        return !hasUnsupportedOp;
    }

    auto getResult() const {
        return tuple(appearance, tensors, strideTable, subscripts);
    }
};

} // namespace nnet

@@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class MatmulTransposeMutator : public Mutator {
    Derivator &derivator;

  public:
    MatmulTransposeMutator(Derivator &derivator)
        : Mutator(1), derivator(derivator) {}
    VecExpr transpose(const Tensor &tensor);

  private:
    Tensor transposeInput(const Tensor &tensor);
};

} // namespace nnet

@@ -0,0 +1,20 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class MergeMemboundMutator : public Mutator {
    VecExpr kernels;
    int curDepth; // from the last one to the first one
    Expr visit_(const Tensor &c) override;
    // FIXME: duplicate code
    Expr rule4StageMerging(Expr &rCur, bool mergeStageWithCalc);
    bool checkEmpty();

  public:
    MergeMemboundMutator(const VecExpr &kernels)
        : Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
    Expr merge(bool allowEmptyMembound = false);
};

} // namespace nnet
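
A sketch (not from the commit): the kernels vector is ordered from the last stage to the first (matching curDepth above), and merge() folds them into one expression. The include path is assumed from the convention above.

#include "nnet/Visitor/MergeMemboundMutator.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
Expr fuseStages(const VecExpr &stagesLastToFirst) {
    return MergeMemboundMutator(stagesLastToFirst).merge();
}

} // namespace nnet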

@@ -0,0 +1,43 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

/**
 * @brief Since the output positions of operators always start from 0, we have
 * to offset them if the boundary expression is not 0.
 */
class PatternMatcher : public Functor<void(void)> {
  private:
    Derivator &derivator;
    bool hasNonZeroRange;
    const RangeOp originalCur;

  public:
    PatternMatcher(Derivator &derivator, const RangeOp &cur);
    /**
     * @brief Get the Cur whose loop vars are all offset to [0, x). Since
     * operator outputs start from 0, RangeOp has to be aligned.
     */
    RangeOp getOffsetCur();
    /**
     * @brief Add an outer RangeOp to map the original positions to the new
     * positions starting from 0.
     *
     * @param exprs Tensors from matched exprs
     */
    VecExpr applyWrapper(const VecExpr &exprs);

    VecExpr matchWithPattern(const RangeOp &rangeOp, const Pattern &pattern);

  private:
    VecExpr matchKernel(const Pattern &pattern, const RangeOp &rangeOp,
                        IteratorTable &exprIT);
    // get reverse tensor and iterator map ([pattern tensor/iter ID] ->
    // real)
    Expr matchKernelWithTensorMap(const Pattern &pattern,
                                  const RangeOp &rangeOp,
                                  IteratorTable &exprIT);
};

} // namespace nnet

@@ -0,0 +1,25 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class RangeMagnifyVisitor : public Mutator {
    vector<VarRangePair> newSumVarRanges;
    RangeOp newRangeOp;

  public:
    RangeMagnifyVisitor() : Mutator(0) {}
    Expr visit_(const RangeOp &c) override;
    Expr visit_(const Subscript &c) override;
    /**
     * @brief
     *
     * @param root
     * @param _newSumVarRanges
     * @return RangeOp nullptr if failed to magnify
     */
    RangeOp magnify(const RangeOp &root,
                    const vector<VarRangePair> &_newSumVarRanges);
};

} // namespace nnet

@@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

using RangeMap = PtrMap<Iterator, Range>;
class RangeRelaxFunctor : public Functor<RangeMap()> {
    RangeOp rangeOp;

  public:
    RangeRelaxFunctor(RangeOp _rangeOp) : Functor(false), rangeOp(_rangeOp) {}
    RangeMap visit_(const BinaryOp &c) override;
    RangeMap visit_(const RangeOp &c) override;
    RangeMap visit_(const Subscript &c) override;
    RangeMap intersectRangeMaps(const RangeMap &a, const RangeMap &b);
};

} // namespace nnet

@@ -0,0 +1,20 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

// Replace a node, identified by its address, in the summand.
// Only subscript and tensor are supported now.
class ReplaceNodeMutator : public Mutator {
    int nSubscripts = 0;
    ExprNode *target;
    Expr replacement;

  public:
    ReplaceNodeMutator() : Mutator(0) {}
    Expr visit_(const Subscript &c) override;
    Expr visit_(const Tensor &c) override;
    Expr replace(const Expr &root, ExprNode *_target, const Expr &_replace);
};

} // namespace nnet
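
Because the target is addressed by raw node pointer, a caller unwraps the Ref-held node first. A sketch (not from the commit), assuming Subscript is a shared_ptr-like Ref exposing get(), and the include path follows the convention above.

#include "nnet/Visitor/ReplaceNodeMutator.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
Expr swapSubscript(const Expr &root, const Subscript &victim,
                   const Expr &replacement) {
    return ReplaceNodeMutator().replace(root, victim.get(), replacement);
}

} // namespace nnet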

@@ -0,0 +1,33 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

class ReplaceVariable : public Mutator {
    VecExpr patterns, replacements;
    map<HashType, int> patternHash;

  public:
    ReplaceVariable(Expr _pattern, Expr _replacement) : Mutator(false) {
        set({_pattern}, {_replacement});
    }
    ReplaceVariable(const map<string, pair<Expr, Expr>> &mapping)
        : Mutator(false) {
        VecExpr _patterns, _replacements;
        for (const auto &[_, v] : mapping) {
            _patterns.emplace_back(v.first);
            _replacements.emplace_back(v.second);
        }
        set(_patterns, _replacements);
    }
    Expr visit_(const BinaryOp &c) override;
    // Does NOT recur into the next stage
    Expr visit_(const RangeOp &c) override;
    Expr visit_(const Var &c) override;

  private:
    void set(VecExpr _pattern, VecExpr _replacement);
    Expr match(const Expr &c);
};

} // namespace nnet

@@ -0,0 +1,52 @@
#pragma once
#include "nlohmann/json_fwd.hpp"
#include "nnet/visitor.h"
#include <memory>

namespace nnet {

class Serializer : public Functor<string()> {
    using json = nlohmann::ordered_json;

  private:
    static constexpr int VERSION{1};
    std::unique_ptr<json> jPtr;
    json &j;
    static int id;

    string visit_(const Constant &c) override;
    string visit_(const BinaryOp &c) override;
    string visit_(const RangeOp &c) override;
    string visit_(const Subscript &c) override;
    string visit_(const Var &c) override;
    string visit_(const Tensor &c) override;
    string dispatchRoutine(const Routine &c);

    Expr buildExprTree(string key);
    Routine buildRoutine(string key);

  public:
    Serializer(int _verobse = 0);
    virtual ~Serializer();

    /**
     * @brief Serialize the given expression to a JSON file
     *
     * @param expr The expression to be serialized
     * @param filePath The path of the JSON file to be output
     * @param msg Message of derivation
     * @return bool Whether the serialization succeeds
     */
    bool serialize(const Expr &expr, const string &filePath,
                   const string &msg = "");

    /**
     * @brief Deserialize the given JSON file to an expression
     *
     * @param filePath The path to the file to be deserialized
     * @return Expression deserialized from the given JSON file
     */
    Expr deserialize(const string &filePath);
};

} // namespace nnet
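
A sketch (not from the commit) of a serialize/deserialize round trip; the file path is illustrative, Expr is assumed to be a null-comparable Ref, and the include path follows the convention above.

#include "nnet/Visitor/Serializer.h"

namespace nnet {

// Hypothetical helper, not part of this commit.
Expr roundTrip(const Expr &expr) {
    if (!Serializer().serialize(expr, "expr.json", "round-trip demo"))
        return nullptr;
    return Serializer().deserialize("expr.json");
}

} // namespace nnet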

@@ -0,0 +1,38 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"

namespace nnet {

// Simplify an index expression tree
class SimplifyExprVisitor : public Functor<void(optional<int> stride)> {
  private:
    SubexprSride subexprStride;
    int constant;
    PtrMap<Iterator, int> strides; // [var]=strides

    map<pair<Iterator, int>, int, RefValueLess<pair<Iterator, int>>> divStrides,
        modStrides; // 3*(i%8): [<i,8>]=3

    // For divide and modulo with an expr as dividend: 3*((i+1)%8): [<i+1,8>]=3
    map<pair<Expr, int>, int, RefAddrLess<pair<Expr, int>>> divExprStrides,
        modExprStrides;

  public:
    SimplifyExprVisitor() : Functor(0) {}
    void visit_(const BinaryOp &c, optional<int> stride) override;
    void visit_(const Var &c, optional<int> stride) override;
    void visit_(const Constant &c, optional<int> stride) override;
    PtrMap<Iterator, int> getStrides(const Expr &expr);
    // TODO [refactor]: move this to SimplifyFormulaMutator as a member func;
    // this class should only get coefficients in an expr
    Expr simplify(const Expr &expr);
    int getConstant(const Expr &expr);
    pair<PtrMap<Iterator, int>, int> getStridesConstant(const Expr &expr);
    optional<Range> getExprRange(const Expr &expr, const RangeOp &rangeOp);
    PtrMap<Iterator, int> getStrides() { return strides; }
    const auto &getDivStrides() { return divStrides; }
    const auto &getModStrides() { return modStrides; }
};

} // namespace nnet

@@ -0,0 +1,18 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"

namespace nnet {

// Simplify all indexes in subscripts in an expression tree
class SimplifyFormulaMutator : public Mutator {
    int nSubscripts = 0;

  public:
    SimplifyFormulaMutator() : Mutator(0) {}
    Expr visit_(const Subscript &c) override;
    // Expr visit_(const BinaryOp &c) override;
    Expr simplify(const Expr &expr);
};

} // namespace nnet

@@ -0,0 +1,38 @@
#pragma once
#include "nnet/visitor.h"

namespace nnet {

using SubexprSride = map<const ExprNode *, optional<int>>;
class StrideVisitor : public Functor<optional<int>(void)> {
  private:
    SubexprSride subexprStride;

  public:
    StrideVisitor(int _verobse = 0) : Functor(_verobse) {}
    optional<int> visit_(const BinaryOp &c) override;
    optional<int> visit_(const Subscript &c) override;
    optional<int> visit_(const Var &c) override;
    optional<int> visit_(const Constant &c) override;
    // void visit_(const Tensor &c, const Tensor &tensor) override;

    auto getFormulaStride(const RangeOp &e) {
        subexprStride.clear();
        // get the location and stride of each iterator
        auto mulOp = as<BinaryOpNode>(e->getSummand());
        // TODO [feature]: support complex index exprs
        if (!mulOp || mulOp->getOpType() != OpType::Mul)
            nnet_unimplemented_continue();
        dispatch(mulOp->getLhs());
        dispatch(mulOp->getRhs());
        return subexprStride;
    }

    [[nodiscard]] auto getExprStride(const Expr &e) {
        subexprStride.clear();
        dispatch(e);
        return subexprStride;
    }
};

} // namespace nnet

@@ -0,0 +1,77 @@
#pragma once
#include "dbg.h"
#include <cassert>
#include <list>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace nnet {
using std::dynamic_pointer_cast;
using std::endl;
using std::list;
using std::make_pair;
using std::make_shared;
using std::make_tuple;
using std::map;
using std::max;
using std::min;
using std::nullopt;
using std::optional;
using std::pair;
using std::set;
using std::shared_ptr;
using std::string;
using std::tie;
using std::to_string;
using std::tuple;
using std::unique_ptr;
using std::unordered_map;
template <typename T> using uset = std::unordered_set<T>;
using std::vector;
using std::weak_ptr;

// Aliases
using dtype = float;
using HashType = int;

template <typename T> struct ptr_less {
    bool operator()(const T &lhs, const T &rhs) const { return *lhs < *rhs; }
};

template <typename T> struct ptr_hash {
    size_t operator()(const T &lhs) const {
        return std::hash<decltype(*lhs)>()(*lhs);
    }
};

template <typename T> struct ptr_equal {
    bool operator()(const T &lhs, const T &rhs) const { return *lhs == *rhs; }
};

static inline HashType genhash(HashType a, HashType b) {
    return (a * 10007 + b + 12345) % 1000000007;
}

static inline HashType genhash(string s) {
    HashType ret = 0;
    for (auto c : s)
        ret = genhash(ret, c);
    return ret;
}

#define nnet_unimplemented_halt() \
    { assert(!"Unimplemented"); }

#define nnet_unimplemented_continue() \
    { dbg("Unimplemented"); }

#define nnet_assert(expr, msg) assert(((void)(msg), (expr)))

std::string pointer_to_hex(void *i);
} // namespace nnet
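
Two small usage notes on this header, as a sketch (not from the commit): genhash() chains pairwise, so a sequence hashes by a left fold, and ptr_less orders smart pointers by pointee value, which is what Ptr-keyed ordered containers rely on.

#include <memory>
#include <set>

namespace nnet {

// Hypothetical helpers, not part of this commit.
// Fold a sequence into one hash with the pairwise genhash().
inline HashType hashSequence(const vector<HashType> &xs) {
    HashType h = 0;
    for (auto x : xs)
        h = genhash(h, x);
    return h;
}

// A set of shared_ptr<string> ordered by the pointed-to strings, not by
// address, thanks to ptr_less.
using StringPtrSet =
    std::set<std::shared_ptr<string>, ptr_less<std::shared_ptr<string>>>;

} // namespace nnet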
|
|
@ -0,0 +1,893 @@
|
|||
/*****************************************************************************
|
||||
|
||||
dbg(...) macro
|
||||
|
||||
License (MIT):
|
||||
|
||||
Copyright (c) 2019 David Peter <mail@david-peter.de>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to
|
||||
deal in the Software without restriction, including without limitation the
|
||||
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
sell copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
*****************************************************************************/
|
||||
|
||||
#ifndef DBG_MACRO_DBG_H
|
||||
#define DBG_MACRO_DBG_H
|
||||
|
||||
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
|
||||
#define DBG_MACRO_UNIX
|
||||
#elif defined(_MSC_VER)
|
||||
#define DBG_MACRO_WINDOWS
|
||||
#endif
|
||||
|
||||
// #ifndef DBG_MACRO_NO_WARNING
|
||||
// #pragma message("WARNING: the 'dbg.h' header is included in your code base")
|
||||
// #endif // DBG_MACRO_NO_WARNING
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <iomanip>
|
||||
#include <ios>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#ifdef DBG_MACRO_UNIX
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#if __cplusplus >= 201703L
|
||||
#define DBG_MACRO_CXX_STANDARD 17
|
||||
#elif __cplusplus >= 201402L
|
||||
#define DBG_MACRO_CXX_STANDARD 14
|
||||
#else
|
||||
#define DBG_MACRO_CXX_STANDARD 11
|
||||
#endif
|
||||
|
||||
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
#endif
|
||||
|
||||
namespace dbg {
|
||||
|
||||
#ifdef DBG_MACRO_UNIX
|
||||
inline bool isColorizedOutputEnabled() {
|
||||
return isatty(fileno(stderr));
|
||||
}
|
||||
#else
|
||||
inline bool isColorizedOutputEnabled() {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct time {};
|
||||
|
||||
namespace pretty_function {
|
||||
|
||||
// Compiler-agnostic version of __PRETTY_FUNCTION__ and constants to
|
||||
// extract the template argument in `type_name_impl`
|
||||
|
||||
#if defined(__clang__)
|
||||
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
|
||||
static constexpr size_t PREFIX_LENGTH =
|
||||
sizeof("const char *dbg::type_name_impl() [T = ") - 1;
|
||||
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
|
||||
#elif defined(__GNUC__) && !defined(__clang__)
|
||||
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
|
||||
static constexpr size_t PREFIX_LENGTH =
|
||||
sizeof("const char* dbg::type_name_impl() [with T = ") - 1;
|
||||
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
|
||||
#elif defined(_MSC_VER)
|
||||
#define DBG_MACRO_PRETTY_FUNCTION __FUNCSIG__
|
||||
static constexpr size_t PREFIX_LENGTH =
|
||||
sizeof("const char *__cdecl dbg::type_name_impl<") - 1;
|
||||
static constexpr size_t SUFFIX_LENGTH = sizeof(">(void)") - 1;
|
||||
#else
|
||||
#error "This compiler is currently not supported by dbg_macro."
|
||||
#endif
|
||||
|
||||
} // namespace pretty_function
|
||||
|
||||
// Formatting helpers
|
||||
|
||||
template <typename T>
|
||||
struct print_formatted {
|
||||
static_assert(std::is_integral<T>::value,
|
||||
"Only integral types are supported.");
|
||||
|
||||
print_formatted(T value, int numeric_base)
|
||||
: inner(value), base(numeric_base) {}
|
||||
|
||||
operator T() const { return inner; }
|
||||
|
||||
const char* prefix() const {
|
||||
switch (base) {
|
||||
case 8:
|
||||
return "0o";
|
||||
case 16:
|
||||
return "0x";
|
||||
case 2:
|
||||
return "0b";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
T inner;
|
||||
int base;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
print_formatted<T> hex(T value) {
|
||||
return print_formatted<T>{value, 16};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
print_formatted<T> oct(T value) {
|
||||
return print_formatted<T>{value, 8};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
print_formatted<T> bin(T value) {
|
||||
return print_formatted<T>{value, 2};
|
||||
}
|
||||
|
||||
// Implementation of 'type_name<T>()'
|
||||
|
||||
template <typename T>
|
||||
const char* type_name_impl() {
|
||||
return DBG_MACRO_PRETTY_FUNCTION;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct type_tag {};
|
||||
|
||||
template <int&... ExplicitArgumentBarrier, typename T>
|
||||
std::string get_type_name(type_tag<T>) {
|
||||
namespace pf = pretty_function;
|
||||
|
||||
std::string type = type_name_impl<T>();
|
||||
return type.substr(pf::PREFIX_LENGTH,
|
||||
type.size() - pf::PREFIX_LENGTH - pf::SUFFIX_LENGTH);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string type_name() {
|
||||
if (std::is_volatile<T>::value) {
|
||||
if (std::is_pointer<T>::value) {
|
||||
return type_name<typename std::remove_volatile<T>::type>() + " volatile";
|
||||
} else {
|
||||
return "volatile " + type_name<typename std::remove_volatile<T>::type>();
|
||||
}
|
||||
}
|
||||
if (std::is_const<T>::value) {
|
||||
if (std::is_pointer<T>::value) {
|
||||
return type_name<typename std::remove_const<T>::type>() + " const";
|
||||
} else {
|
||||
return "const " + type_name<typename std::remove_const<T>::type>();
|
||||
}
|
||||
}
|
||||
if (std::is_pointer<T>::value) {
|
||||
return type_name<typename std::remove_pointer<T>::type>() + "*";
|
||||
}
|
||||
if (std::is_lvalue_reference<T>::value) {
|
||||
return type_name<typename std::remove_reference<T>::type>() + "&";
|
||||
}
|
||||
if (std::is_rvalue_reference<T>::value) {
|
||||
return type_name<typename std::remove_reference<T>::type>() + "&&";
|
||||
}
|
||||
return get_type_name(type_tag<T>{});
|
||||
}
|
||||
|
||||
inline std::string get_type_name(type_tag<short>) {
|
||||
return "short";
|
||||
}
|
||||
|
||||
inline std::string get_type_name(type_tag<unsigned short>) {
|
||||
return "unsigned short";
|
||||
}
|
||||
|
||||
inline std::string get_type_name(type_tag<long>) {
|
||||
return "long";
|
||||
}
|
||||
|
||||
inline std::string get_type_name(type_tag<unsigned long>) {
|
||||
return "unsigned long";
|
||||
}
|
||||
|
||||
inline std::string get_type_name(type_tag<std::string>) {
|
||||
return "std::string";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string get_type_name(type_tag<std::vector<T, std::allocator<T>>>) {
|
||||
return "std::vector<" + type_name<T>() + ">";
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
std::string get_type_name(type_tag<std::pair<T1, T2>>) {
|
||||
return "std::pair<" + type_name<T1>() + ", " + type_name<T2>() + ">";
|
||||
}
|
||||
|
||||
template <typename... T>
|
||||
std::string type_list_to_string() {
|
||||
std::string result;
|
||||
auto unused = {(result += type_name<T>() + ", ", 0)..., 0};
|
||||
static_cast<void>(unused);
|
||||
|
||||
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||
if constexpr (sizeof...(T) > 0) {
|
||||
#else
|
||||
if (sizeof...(T) > 0) {
|
||||
#endif
|
||||
result.pop_back();
|
||||
result.pop_back();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename... T>
|
||||
std::string get_type_name(type_tag<std::tuple<T...>>) {
|
||||
return "std::tuple<" + type_list_to_string<T...>() + ">";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string get_type_name(type_tag<print_formatted<T>>) {
|
||||
return type_name<T>();
|
||||
}
|
||||
|
||||
// Implementation of 'is_detected' to specialize for container-like types
|
||||
|
||||
namespace detail_detector {
|
||||
|
||||
struct nonesuch {
|
||||
nonesuch() = delete;
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
void operator=(nonesuch const&) = delete;
|
||||
};
|
||||
|
||||
template <typename...>
|
||||
using void_t = void;
|
||||
|
||||
template <class Default,
|
||||
class AlwaysVoid,
|
||||
template <class...>
|
||||
class Op,
|
||||
class... Args>
|
||||
struct detector {
|
||||
using value_t = std::false_type;
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <class Default, template <class...> class Op, class... Args>
|
||||
struct detector<Default, void_t<Op<Args...>>, Op, Args...> {
|
||||
using value_t = std::true_type;
|
||||
using type = Op<Args...>;
|
||||
};
|
||||
|
||||
} // namespace detail_detector
|
||||
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail_detector::
|
||||
detector<detail_detector::nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
namespace detail {
|
||||
|
||||
namespace {
|
||||
using std::begin;
|
||||
using std::end;
|
||||
#if DBG_MACRO_CXX_STANDARD < 17
|
||||
template <typename T>
|
||||
constexpr auto size(const T& c) -> decltype(c.size()) {
|
||||
return c.size();
|
||||
}
|
||||
template <typename T, std::size_t N>
|
||||
constexpr std::size_t size(const T (&)[N]) {
|
||||
return N;
|
||||
}
|
||||
#else
|
||||
using std::size;
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
using detect_begin_t = decltype(detail::begin(std::declval<T>()));
|
||||
|
||||
template <typename T>
|
||||
using detect_end_t = decltype(detail::end(std::declval<T>()));
|
||||
|
||||
template <typename T>
|
||||
using detect_size_t = decltype(detail::size(std::declval<T>()));
|
||||
|
||||
template <typename T>
|
||||
struct is_container {
|
||||
static constexpr bool value =
|
||||
is_detected<detect_begin_t, T>::value &&
|
||||
is_detected<detect_end_t, T>::value &&
|
||||
is_detected<detect_size_t, T>::value &&
|
||||
!std::is_same<std::string,
|
||||
typename std::remove_cv<
|
||||
typename std::remove_reference<T>::type>::type>::value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using ostream_operator_t =
|
||||
decltype(std::declval<std::ostream&>() << std::declval<T>());
|
||||
|
||||
template <typename T>
|
||||
struct has_ostream_operator : is_detected<ostream_operator_t, T> {};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Helper to dbg(…)-print types
|
||||
template <typename T>
|
||||
struct print_type {};
|
||||
|
||||
template <typename T>
|
||||
print_type<T> type() {
|
||||
return print_type<T>{};
|
||||
}
|
||||
|
||||
// Forward declarations of "pretty_print"
|
||||
|
||||
template <typename T>
|
||||
inline void pretty_print(std::ostream& stream, const T& value, std::true_type);
|
||||
|
||||
template <typename T>
|
||||
inline void pretty_print(std::ostream&, const T&, std::false_type);
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<!detail::is_container<const T&>::value &&
|
||||
!std::is_enum<T>::value,
|
||||
bool>::type
|
||||
pretty_print(std::ostream& stream, const T& value);
|
||||
|
||||
inline bool pretty_print(std::ostream& stream, const bool& value);
|
||||
|
||||
inline bool pretty_print(std::ostream& stream, const char& value);
|
||||
|
||||
template <typename P>
|
||||
inline bool pretty_print(std::ostream& stream, P* const& value);
|
||||
|
||||
template <typename T, typename Deleter>
|
||||
inline bool pretty_print(std::ostream& stream,
|
||||
std::unique_ptr<T, Deleter>& value);
|
||||
|
||||
// template <typename T>
|
||||
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value);
|
||||
|
||||
template <size_t N>
|
||||
inline bool pretty_print(std::ostream& stream, const char (&value)[N]);
|
||||
|
||||
template <>
|
||||
inline bool pretty_print(std::ostream& stream, const char* const& value);
|
||||
|
||||
template <typename... Ts>
|
||||
inline bool pretty_print(std::ostream& stream, const std::tuple<Ts...>& value);
|
||||
|
||||
template <>
|
||||
inline bool pretty_print(std::ostream& stream, const std::tuple<>&);
|
||||
|
||||
template <>
|
||||
inline bool pretty_print(std::ostream& stream, const time&);
|
||||
|
||||
template <typename T>
|
||||
inline bool pretty_print(std::ostream& stream,
|
||||
const print_formatted<T>& value);
|
||||
|
||||
template <typename T>
|
||||
inline bool pretty_print(std::ostream& stream, const print_type<T>&);
|
||||
|
||||
template <typename Enum>
|
||||
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
|
||||
pretty_print(std::ostream& stream, Enum const& value);
|
||||
|
||||
inline bool pretty_print(std::ostream& stream, const std::string& value);
|
||||
|
||||
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||
|
||||
inline bool pretty_print(std::ostream& stream, const std::string_view& value);
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T1, typename T2>
|
||||
inline bool pretty_print(std::ostream& stream, const std::pair<T1, T2>& value);
|
||||
|
||||
#if DBG_MACRO_CXX_STANDARD >= 17
|
||||
|
||||
template <typename T>
|
||||
inline bool pretty_print(std::ostream& stream, const std::optional<T>& value);
|
||||
|
||||
template <typename... Ts>
|
||||
inline bool pretty_print(std::ostream& stream,
|
||||
const std::variant<Ts...>& value);
|
||||
|
||||
#endif
|
||||
|
||||
template <typename Container>
|
||||
inline typename std::enable_if<detail::is_container<const Container&>::value,
|
||||
bool>::type
|
||||
pretty_print(std::ostream& stream, const Container& value);
|
||||
|
||||
// Specializations of "pretty_print"
|
||||
|
||||
template <typename T>
|
||||
inline void pretty_print(std::ostream& stream, const T& value, std::true_type) {
|
||||
stream << value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void pretty_print(std::ostream&, const T&, std::false_type) {
|
||||
static_assert(detail::has_ostream_operator<const T&>::value,
|
||||
"Type does not support the << ostream operator");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<!detail::is_container<const T&>::value &&
|
||||
!std::is_enum<T>::value,
|
||||
bool>::type
|
||||
pretty_print(std::ostream& stream, const T& value) {
|
||||
pretty_print(stream, value,
|
||||
typename detail::has_ostream_operator<const T&>::type{});
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool pretty_print(std::ostream& stream, const bool& value) {
|
||||
  stream << std::boolalpha << value;
  return true;
}

inline bool pretty_print(std::ostream& stream, const char& value) {
  const bool printable = value >= 0x20 && value <= 0x7E;

  if (printable) {
    stream << "'" << value << "'";
  } else {
    stream << "'\\x" << std::setw(2) << std::setfill('0') << std::hex
           << std::uppercase << (0xFF & value) << "'";
  }
  return true;
}

template <typename P>
inline bool pretty_print(std::ostream& stream, P* const& value) {
  if (value == nullptr) {
    stream << "nullptr";
  } else {
    stream << value;
  }
  return true;
}

template <typename T, typename Deleter>
inline bool pretty_print(std::ostream& stream,
                         std::unique_ptr<T, Deleter>& value) {
  pretty_print(stream, value.get());
  return true;
}

// template <typename T>
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value) {
//   pretty_print(stream, value.get());
//   stream << " (use_count = " << value.use_count() << ")";
//   return true;
// }

template <size_t N>
inline bool pretty_print(std::ostream& stream, const char (&value)[N]) {
  stream << value;
  return false;
}

template <>
inline bool pretty_print(std::ostream& stream, const char* const& value) {
  stream << '"' << value << '"';
  return true;
}

template <size_t Idx>
struct pretty_print_tuple {
  template <typename... Ts>
  static void print(std::ostream& stream, const std::tuple<Ts...>& tuple) {
    pretty_print_tuple<Idx - 1>::print(stream, tuple);
    stream << ", ";
    pretty_print(stream, std::get<Idx>(tuple));
  }
};

template <>
struct pretty_print_tuple<0> {
  template <typename... Ts>
  static void print(std::ostream& stream, const std::tuple<Ts...>& tuple) {
    pretty_print(stream, std::get<0>(tuple));
  }
};

template <typename... Ts>
inline bool pretty_print(std::ostream& stream, const std::tuple<Ts...>& value) {
  stream << "{";
  pretty_print_tuple<sizeof...(Ts) - 1>::print(stream, value);
  stream << "}";

  return true;
}

template <>
inline bool pretty_print(std::ostream& stream, const std::tuple<>&) {
  stream << "{}";

  return true;
}

template <>
inline bool pretty_print(std::ostream& stream, const time&) {
  using namespace std::chrono;

  const auto now = system_clock::now();
  const auto us =
      duration_cast<microseconds>(now.time_since_epoch()).count() % 1000000;
  const auto hms = system_clock::to_time_t(now);
  const std::tm* tm = std::localtime(&hms);
  stream << "current time = " << std::put_time(tm, "%H:%M:%S") << '.'
         << std::setw(6) << std::setfill('0') << us;

  return false;
}

// Converts decimal integer to binary string
template <typename T>
std::string decimalToBinary(T n) {
  const size_t length = 8 * sizeof(T);
  std::string toRet;
  toRet.resize(length);

  for (size_t i = 0; i < length; ++i) {
    const auto bit_at_index_i = static_cast<char>((n >> i) & 1);
    toRet[length - 1 - i] = bit_at_index_i + '0';
  }

  return toRet;
}

template <typename T>
inline bool pretty_print(std::ostream& stream,
                         const print_formatted<T>& value) {
  if (value.inner < 0) {
    stream << "-";
  }
  stream << value.prefix();

  // Print using setbase
  if (value.base != 2) {
    stream << std::setw(sizeof(T)) << std::setfill('0')
           << std::setbase(value.base) << std::uppercase;
    if (value.inner >= 0) {
      // The '+' sign makes sure that a uint8_t is printed as a number
      stream << +value.inner;
    } else {
      using unsigned_type = typename std::make_unsigned<T>::type;
      stream << +(static_cast<unsigned_type>(-(value.inner + 1)) + 1);
    }
  } else {
    // Print for binary
    if (value.inner >= 0) {
      stream << decimalToBinary(value.inner);
    } else {
      using unsigned_type = typename std::make_unsigned<T>::type;
      stream << decimalToBinary<unsigned_type>(
          static_cast<unsigned_type>(-(value.inner + 1)) + 1);
    }
  }

  return true;
}

template <typename T>
inline bool pretty_print(std::ostream& stream, const print_type<T>&) {
  stream << type_name<T>();

  stream << " [sizeof: " << sizeof(T) << " byte, ";

  stream << "trivial: ";
  if (std::is_trivial<T>::value) {
    stream << "yes";
  } else {
    stream << "no";
  }

  stream << ", standard layout: ";
  if (std::is_standard_layout<T>::value) {
    stream << "yes";
  } else {
    stream << "no";
  }
  stream << "]";

  return false;
}

template <typename Enum>
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
pretty_print(std::ostream& stream, Enum const& value) {
  using UnderlyingType = typename std::underlying_type<Enum>::type;
  stream << static_cast<UnderlyingType>(value);

  return true;
}

inline bool pretty_print(std::ostream& stream, const std::string& value) {
  stream << '"' << value << '"';
  return true;
}

#if DBG_MACRO_CXX_STANDARD >= 17

inline bool pretty_print(std::ostream& stream, const std::string_view& value) {
  stream << '"' << std::string(value) << '"';
  return true;
}

#endif

template <typename T1, typename T2>
inline bool pretty_print(std::ostream& stream, const std::pair<T1, T2>& value) {
  stream << "{";
  pretty_print(stream, value.first);
  stream << ", ";
  pretty_print(stream, value.second);
  stream << "}";
  return true;
}

#if DBG_MACRO_CXX_STANDARD >= 17

template <typename T>
inline bool pretty_print(std::ostream& stream, const std::optional<T>& value) {
  if (value) {
    stream << '{';
    pretty_print(stream, *value);
    stream << '}';
  } else {
    stream << "nullopt";
  }

  return true;
}

template <typename... Ts>
inline bool pretty_print(std::ostream& stream,
                         const std::variant<Ts...>& value) {
  stream << "{";
  std::visit([&stream](auto&& arg) { pretty_print(stream, arg); }, value);
  stream << "}";

  return true;
}

#endif

template <typename Container>
inline typename std::enable_if<detail::is_container<const Container&>::value,
                               bool>::type
pretty_print(std::ostream& stream, const Container& value) {
  stream << "{";
  const size_t size = detail::size(value);
  const size_t n = std::min(size_t{10}, size);
  size_t i = 0;
  using std::begin;
  using std::end;
  for (auto it = begin(value); it != end(value) && i < n; ++it, ++i) {
    pretty_print(stream, *it);
    if (i != n - 1) {
      stream << ", ";
    }
  }

  if (size > n) {
    stream << ", ...";
    stream << " size:" << size;
  }

  stream << "}";
  return true;
}

template <typename T, typename... U>
struct last {
  using type = typename last<U...>::type;
};

template <typename T>
struct last<T> {
  using type = T;
};

template <typename... T>
using last_t = typename last<T...>::type;

class DebugOutput {
 public:
  // Helper alias to avoid obscure type `const char* const*` in signature.
  using expr_t = const char*;

  DebugOutput(const char* filepath, int line, const char* function_name)
      : m_use_colorized_output(isColorizedOutputEnabled()) {
    std::string path = filepath;
    const std::size_t path_length = path.length();
    if (path_length > MAX_PATH_LENGTH) {
      path = ".." + path.substr(path_length - MAX_PATH_LENGTH, MAX_PATH_LENGTH);
    }
    std::stringstream ss;
    ss << ansi(ANSI_DEBUG) << "[" << path << ":" << line << " ("
       << function_name << ")] " << ansi(ANSI_RESET);
    m_location = ss.str();
  }

  template <typename... T>
  auto print(std::initializer_list<expr_t> exprs,
             std::initializer_list<std::string> types,
             T&&... values) -> last_t<T...> {
    if (exprs.size() != sizeof...(values)) {
      std::cerr
          << m_location << ansi(ANSI_WARN)
          << "The number of arguments mismatch, please check unprotected comma"
          << ansi(ANSI_RESET) << std::endl;
    }
    return print_impl(exprs.begin(), types.begin(), std::forward<T>(values)...);
  }

 private:
  template <typename T>
  T&& print_impl(const expr_t* expr, const std::string* type, T&& value) {
    const T& ref = value;
    std::stringstream stream_value;
    const bool print_expr_and_type = pretty_print(stream_value, ref);

    std::stringstream output;
    output << m_location;
    if (print_expr_and_type) {
      output << ansi(ANSI_EXPRESSION) << *expr << ansi(ANSI_RESET) << " = ";
    }
    output << ansi(ANSI_VALUE) << stream_value.str() << ansi(ANSI_RESET);
    if (print_expr_and_type) {
      output << " (" << ansi(ANSI_TYPE) << *type << ansi(ANSI_RESET) << ")";
    }
    output << std::endl;
    std::cerr << output.str();

    return std::forward<T>(value);
  }

  template <typename T, typename... U>
  auto print_impl(const expr_t* exprs,
                  const std::string* types,
                  T&& value,
                  U&&... rest) -> last_t<T, U...> {
    print_impl(exprs, types, std::forward<T>(value));
    return print_impl(exprs + 1, types + 1, std::forward<U>(rest)...);
  }

  const char* ansi(const char* code) const {
    if (m_use_colorized_output) {
      return code;
    } else {
      return ANSI_EMPTY;
    }
  }

  const bool m_use_colorized_output;

  std::string m_location;

  static constexpr std::size_t MAX_PATH_LENGTH = 20;

  static constexpr const char* const ANSI_EMPTY = "";
  static constexpr const char* const ANSI_DEBUG = "\x1b[02m";
  static constexpr const char* const ANSI_WARN = "\x1b[33m";
  static constexpr const char* const ANSI_EXPRESSION = "\x1b[36m";
  static constexpr const char* const ANSI_VALUE = "\x1b[01m";
  static constexpr const char* const ANSI_TYPE = "\x1b[32m";
  static constexpr const char* const ANSI_RESET = "\x1b[0m";
};

// Identity function to suppress "-Wunused-value" warnings in DBG_MACRO_DISABLE
// mode
template <typename T>
T&& identity(T&& t) {
  return std::forward<T>(t);
}

template <typename T, typename... U>
auto identity(T&&, U&&... u) -> last_t<U...> {
  return identity(std::forward<U>(u)...);
}

}  // namespace dbg

#ifndef DBG_MACRO_DISABLE

// Force expanding argument with commas for MSVC, ref:
// https://stackoverflow.com/questions/35210637/macro-expansion-argument-with-commas
// Note that "args" should be a tuple with parentheses, such as "(e1, e2, ...)".
#define DBG_IDENTITY(x) x
#define DBG_CALL(fn, args) DBG_IDENTITY(fn args)

#define DBG_CAT_IMPL(_1, _2) _1##_2
#define DBG_CAT(_1, _2) DBG_CAT_IMPL(_1, _2)

#define DBG_16TH_IMPL(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
                      _14, _15, _16, ...)                                      \
  _16
#define DBG_16TH(args) DBG_CALL(DBG_16TH_IMPL, args)
#define DBG_NARG(...) \
  DBG_16TH((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))

// DBG_VARIADIC_CALL(fn, data, e1, e2, ...) => fn_N(data, (e1, e2, ...))
#define DBG_VARIADIC_CALL(fn, data, ...) \
  DBG_CAT(fn##_, DBG_NARG(__VA_ARGS__))(data, (__VA_ARGS__))

// (e1, e2, e3, ...) => e1
#define DBG_HEAD_IMPL(_1, ...) _1
#define DBG_HEAD(args) DBG_CALL(DBG_HEAD_IMPL, args)

// (e1, e2, e3, ...) => (e2, e3, ...)
#define DBG_TAIL_IMPL(_1, ...) (__VA_ARGS__)
#define DBG_TAIL(args) DBG_CALL(DBG_TAIL_IMPL, args)

#define DBG_MAP_1(fn, args) DBG_CALL(fn, args)
#define DBG_MAP_2(fn, args) fn(DBG_HEAD(args)), DBG_MAP_1(fn, DBG_TAIL(args))
#define DBG_MAP_3(fn, args) fn(DBG_HEAD(args)), DBG_MAP_2(fn, DBG_TAIL(args))
#define DBG_MAP_4(fn, args) fn(DBG_HEAD(args)), DBG_MAP_3(fn, DBG_TAIL(args))
#define DBG_MAP_5(fn, args) fn(DBG_HEAD(args)), DBG_MAP_4(fn, DBG_TAIL(args))
#define DBG_MAP_6(fn, args) fn(DBG_HEAD(args)), DBG_MAP_5(fn, DBG_TAIL(args))
#define DBG_MAP_7(fn, args) fn(DBG_HEAD(args)), DBG_MAP_6(fn, DBG_TAIL(args))
#define DBG_MAP_8(fn, args) fn(DBG_HEAD(args)), DBG_MAP_7(fn, DBG_TAIL(args))
#define DBG_MAP_9(fn, args) fn(DBG_HEAD(args)), DBG_MAP_8(fn, DBG_TAIL(args))
#define DBG_MAP_10(fn, args) fn(DBG_HEAD(args)), DBG_MAP_9(fn, DBG_TAIL(args))
#define DBG_MAP_11(fn, args) fn(DBG_HEAD(args)), DBG_MAP_10(fn, DBG_TAIL(args))
#define DBG_MAP_12(fn, args) fn(DBG_HEAD(args)), DBG_MAP_11(fn, DBG_TAIL(args))
#define DBG_MAP_13(fn, args) fn(DBG_HEAD(args)), DBG_MAP_12(fn, DBG_TAIL(args))
#define DBG_MAP_14(fn, args) fn(DBG_HEAD(args)), DBG_MAP_13(fn, DBG_TAIL(args))
#define DBG_MAP_15(fn, args) fn(DBG_HEAD(args)), DBG_MAP_14(fn, DBG_TAIL(args))
#define DBG_MAP_16(fn, args) fn(DBG_HEAD(args)), DBG_MAP_15(fn, DBG_TAIL(args))

// DBG_MAP(fn, e1, e2, e3, ...) => fn(e1), fn(e2), fn(e3), ...
#define DBG_MAP(fn, ...) DBG_VARIADIC_CALL(DBG_MAP, fn, __VA_ARGS__)

#define DBG_STRINGIFY_IMPL(x) #x
#define DBG_STRINGIFY(x) DBG_STRINGIFY_IMPL(x)

#define DBG_TYPE_NAME(x) dbg::type_name<decltype(x)>()

#define dbg(...)                                    \
  dbg::DebugOutput(__FILE__, __LINE__, __func__)    \
      .print({DBG_MAP(DBG_STRINGIFY, __VA_ARGS__)}, \
             {DBG_MAP(DBG_TYPE_NAME, __VA_ARGS__)}, __VA_ARGS__)
#else
#define dbg(...) dbg::identity(__VA_ARGS__)
#endif  // DBG_MACRO_DISABLE

#endif  // DBG_MACRO_DBG_H
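// Usage sketch for the dbg(...) macro defined above (illustrative only;
// `x` and `factorial` are hypothetical):
//
//   int x = 41;
//   dbg(x);                    // prints "[file.cpp:7 (main)] x = 41 (int)"
//   int y = dbg(factorial(4)); // dbg(...) returns its argument, so it can
//                              // be wrapped around expressions transparently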
@ -0,0 +1,156 @@
#pragma once
#include "common.h"
#include "expr.h"
#include "iterator_table.h"
#include "routine.h"
#include <iostream>
#include <sstream>
#include <unordered_set>

namespace nnet {

class Formula {
public:
    Expr root;
    const int bfsDepth;

public:
    Formula(Expr _root, int _bfsDepth) : root(_root), bfsDepth(_bfsDepth) {}
    string toReadable() const;
    friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
        ios << expr.toReadable();
        return ios;
    }
    bool isVariable() const { return as<VarNode>(root) != nullptr; }
};

class MultiFormulas {
public:
    VecExpr roots;
    const int bfsDepth;

public:
    MultiFormulas(VecExpr roots, int _bfsDepth)
        : roots(roots), bfsDepth(_bfsDepth) {}
    // string toReadable() const;
    // friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
    //     ios << expr.toReadable();
    //     return ios;
    // }
};

class Derivator {
public:
    enum class LogMode { Normal, DumpFristCandiate, NoLog };
    enum class PassMode { Debug, Full };

private:
    list<Formula> candidates;
    const int maxDepth;
    int nIteratorNames = 0;
    int nTensorNames = 0;
    vector<vector<int>> rulesOverall;
    enum class Strategy { DFS, Rule, RuleAndDFS } searchStrategy;
    LogMode logMode;
    PassMode passMode;
    bool enableEquivalenceCheck = false;
    string logFnPrefix;
    const bool enableHashPruning;
    int searchedMaxDepth = 0;
    RoutineType targetOp = RoutineType::NoneType;
    map<int, vector<Var>> substituteRules;

    vector<int> cntAppliedRules;
    int cntRule3 = 0;
    std::unordered_set<HashType> visited;
    VecExpr intermediateStates;
    vector<string> ruleStates, ruleMsgs;
    int cntStates = 0;  // the number of intermediate states
    int searchState = 0; // search state in guided search

public:
    Derivator(int maxDepth = 8, bool enableHashPruning = true,
              LogMode mode = LogMode::NoLog,
              PassMode passMode = PassMode::Debug);
    void search(Formula &origin, int depth);
    void ruleBasedDFS(Formula &origin, int depth, vector<int> _rules,
                      map<int, vector<Var>> _substituteRules = {},
                      bool searchAfterRules = false);
    void guidedSearch(Formula &origin, int depth);
    void print();
    int getNumCandidates() const { return candidates.size(); }
    const auto &getCandidates() const { return candidates; }
    void appendCanddiate(const Tensor &tensor, int depth);
    int getSearchedMaxDepth() const { return searchedMaxDepth; };
    bool stageCombination(MultiFormulas &origin, int depth);
    bool checkOOB(const RangeOp &rangeOp, bool halt = true);

    string newTensorName();
    Var getNewVar();

    Expr mergeMemboundStages(VecExpr stages);

private:
    void dfs(Formula &origin, int depth);
    void ruleBasedDerivate(Formula &origin, int depth);

    void rule1VariableSplit(Formula &origin, int depth, Expr &rCur);
    void rule2VariableMerging(Formula &origin, int depth, Expr &rCur);
    void rule3StageSplit(Formula &origin, int dfsDepth, Expr &rCur);
    void rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
    bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
                           bool mergeStageWithCalc = false);
    void rule6KenerlMatching(Formula &origin, int depth, Expr &rCur);
    void rule7DLT(Formula &origin, int depth, Expr &rCur);
    // Rule 8: guidedDLT
    void rule8GuidedDLT(Formula &origin, int depth, Expr &rCur);
    void rule9RangeMagnify(Formula &origin, int depth, Expr &rCur);
    void rule90TwoStageElementWise(Formula &origin, int depth, Expr &rCur);
    void rule91MergeStagesWithSum(Formula &origin, int depth, Expr &rCur);
    /**
     * @brief For searchState=2: if the boundary does not start from 0, wrap
     * the RangeOp to add an offset, then match the inner offset RangeOp.
     */
    void matchComputationKernel(Formula &origin, int depth, Expr &rcur);
    /**
     * @brief For searchState=3, the Formula is expected to be a MemBound
     * kernel.
     */
    void matchMemBoundKernel(Formula &origin, int depth, Expr &rcur);

    /**
     * @brief Check the equivalence for exprs in intermediateStates.
     */
    void checkDerivationEquivalence();

public:
    void pushIntermediateState(const Expr &expr);
    void pushRuleState(const string &state);
    void pushRuleMsg(const string &state);
    void popIntermediateState();
    void popRuleState();
    void popRuleMsg();
    // void pushTransformInfo(const Expr &expr, const string &state,
    //                        const string &msg);
    void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur);

    RoutineType getTargetOp();
    void setTargetOp(RoutineType _targetOp);

    int getSearchState();
    void setSearchState(int _searchState);
    int getNumIntermediateStates();
    void printStatistics();
    void printIntermediateStates();
    /**
     * @brief Enable dumping the first results. Verification is enabled, so it
     * is slow.
     *
     * @param _logFnPrefix Prefix of the output filename
     */
    void setDumpFirstSuccess(const string &_logFnPrefix);
    void setEquivalenceCheck();
    PassMode getPassMode();
    LogMode getLogMode();
};

} // namespace nnet
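// Usage sketch for the Derivator API above (the rule sequence and the input
// expression are hypothetical; only methods declared in this header are used):
//
//   Formula origin(expr, 0);            // `expr` is a RangeOp built elsewhere
//   Derivator derivator(/*maxDepth=*/7);
//   derivator.ruleBasedDFS(origin, 0, {3, 2, 2, 5, 8, 8, 6, 90});
//   for (const auto &candidate : derivator.getCandidates())
//       /* inspect candidate.root */;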
@ -0,0 +1,53 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>

namespace nnet {

// enum class DLTType { Split, Merge, Reorder };

struct DLTOperation {
    // DLTType type;
    virtual ~DLTOperation() {}
};
struct DLTSplit : DLTOperation {
    int dim, factor;
    DLTSplit(int _dim, int _factor) : dim(_dim), factor(_factor) {}
};
struct DLTMerge : DLTOperation {
    int dim0, dim1;
    DLTMerge(int _dim0, int _dim1) : dim0(_dim0), dim1(_dim1) {}
};
struct DLTReorder : DLTOperation {
    vector<int> dims;
    DLTReorder(vector<int> _dims) : dims(_dims) {}
};

class DLT {
    vector<Ref<DLTOperation>> ops;

public:
    /**
     * @brief dim -> (dim/factor, factor)
     */
    void split(int dim, int factor);
    /**
     * @brief Merge dim1 into dim0 -> (dim0, dim1)
     */
    void merge(int dim0, int dim1);
    /**
     * @brief Reorder (permute) the dimensions.
     *
     * @param dims dims[new_dim]=old_dim
     */
    void reorder(vector<int> dims);
    optional<Expr> apply(const RangeOp &rangeOp, const Subscript &subscript,
                         string newTensorName);

private:
    optional<pair<Expr, Expr>> splitIndex(Expr expr, int factor,
                                          RangeOp rangeOp);
};

} // namespace nnet
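// Usage sketch for a data-layout transformation (a hypothetical layout change;
// only methods declared above are used):
//
//   DLT dlt;
//   dlt.split(1, 4);        // dim 1 -> (dim1/4, 4)
//   dlt.reorder({0, 2, 1}); // permute the resulting dims
//   auto transformed = dlt.apply(rangeOp, subscript, "T_dlt");
//   if (transformed) { /* use *transformed */ }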
@ -0,0 +1,416 @@
#pragma once
#include "common.h"
#include "ref.h"
#include <iostream>
#include <numeric>
#include <type_traits>

namespace nnet {

class ExprNode;
class VarNode;
class TensorNode;
class OperatorNode;
class RangeOpNode;
class SubscriptNode;
class BinaryOpNode;
class ConstantNode;
class FuncNode;
using Expr = Ref<ExprNode>;
using Var = Ref<VarNode>;
using Tensor = Ref<TensorNode>;
using Operator = Ref<OperatorNode>;
using RangeOp = Ref<RangeOpNode>;
using Subscript = Ref<SubscriptNode>;
using BinaryOp = Ref<BinaryOpNode>;
using Constant = Ref<ConstantNode>;
using Func = Ref<FuncNode>;

class RoutineNode;
using Routine = Ref<RoutineNode>;
enum class RoutineType {
    NoneType = 100,
    MatmulNodeType,
    ConvNodeType,
    G2bmmNodeType,
    GbmmNodeType,
    ElementWiseNodeType // unmatchable
};
constexpr inline int MatchableRoutineTypeCnt = 4;
constexpr inline int RoutineTypeCnt = MatchableRoutineTypeCnt + 1;
inline RoutineType idToRoutineType(int i) {
    return static_cast<RoutineType>(i + 1 +
                                    static_cast<int>(RoutineType::NoneType));
}
inline int routineTypeToId(const RoutineType &routineType) {
    return static_cast<int>(routineType) -
           static_cast<int>(RoutineType::NoneType) - 1;
}
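// Sanity check of the id mapping above: ids 0..3 name the four matchable
// routines, e.g. idToRoutineType(0) == RoutineType::MatmulNodeType and
// routineTypeToId(RoutineType::ConvNodeType) == 1.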

using VecExpr = vector<Expr>;

// common data structures
using Iterator = Var; // RE: remove this alias
// When keys are pointers, compare keys according to their values instead of
// their addresses. Specifically, the names of Vars are compared, due to the
// overloads of operator== and hash.
template <typename T, typename U> using PtrMap = std::map<T, U, ptr_less<T>>;
template <typename T, typename U>
using PtrUmap = std::unordered_map<T, U, ptr_hash<T>, ptr_equal<T>>;
template <typename T>
using PtrUset = std::unordered_set<T, ptr_hash<T>, ptr_equal<T>>;
using Appearance = PtrMap<Var, vector<pair<Tensor, int>>>;
using StrideTable =
    PtrMap<Var, vector<tuple<TensorNode *, int, int>>>; // Tensor, dim, stride

// AST node operators
bool operator==(const Var &lhs, const string &rhs);
bool operator==(const string &lhs, const Var &rhs);
Expr operator+(const Expr &lhs, const Expr &rhs);
BinaryOp operator-(const Expr &lhs, const Expr &rhs);
BinaryOp operator*(const Expr &lhs, const Expr &rhs);
BinaryOp operator/(const Expr &lhs, const Expr &rhs);
BinaryOp operator%(const Expr &lhs, const Expr &rhs);

Expr operator+(const Expr &lhs, const int &rhs);
Expr operator+(const int &lhs, const Expr &rhs);
Expr operator-(const Expr &lhs, const int &rhs);
Expr operator-(const int &lhs, const Expr &rhs);
Expr operator*(const Expr &lhs, const int &rhs);
Expr operator*(const int &lhs, const Expr &rhs);
Expr operator%(const Expr &lhs, const int rhs);
Expr operator/(const Expr &lhs, const int rhs);

string serializeVec(vector<Expr> v);
string serializeVec(vector<Var> v);
template <typename T> inline string serializeVec(vector<T> v) {
    if (v.empty())
        return "[]";
    return "[" +
           std::accumulate(
               v.begin() + 1, v.end(), to_string(v[0]),
               [](const string &a, int b) { return a + ',' + to_string(b); }) +
           "]";
}

// For RTTI and the visitor pattern
enum class NodeType {
    ConstantNodeType,
    BinaryOpNodeType,
    RangeOpNodeType,
    SubscriptNodeType,
    TensorNodeType,
    VarNodeType,
    FuncNodeType
};

enum class FuncType { Relu, Tanh };

#define DEFINE_GETTYPE(CLASS)                                                  \
    NodeType getType() const override { return NodeType::CLASS##Type; }

class ExprNode {
public:
    virtual ~ExprNode() {}
    ExprNode &operator=(const ExprNode &rhs) = delete;

    virtual HashType hash() const = 0; // RE: remove?
    virtual string toReadable() const = 0;
    friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr);

    virtual NodeType getType() const = 0;
};

class VarNode : public ExprNode {
    std::string name;

public:
    VarNode(std::string _name) : name(_name){};
    virtual ~VarNode() {}
    DEFINE_GETTYPE(VarNode);

    const std::string &getName() const { return name; }
    HashType hash() const override { return genhash(name); };
    string toReadable() const override { return name; };
    bool equal(const Var &rhs) const { return name == rhs->getName(); }
    bool neq(const Var &rhs) const { return !equal(rhs); }
    bool less(const Var &rhs) const { return name < rhs->getName(); }
    bool equal(const string &rhs) const { return name == rhs; }
    bool operator==(const VarNode &rhs) const { return name == rhs.getName(); }
    bool operator<(const VarNode &rhs) const { return name < rhs.getName(); }
};

enum class TensorType { Input, Weight, Intermediate };

class TensorNode : public ExprNode {
    string name;
    vector<int> shape, paddings;
    TensorType type;
    Routine source; // if there is NO source, this is an input/weight tensor

public:
    TensorNode(string _name, vector<int> _shape, vector<int> _paddings = {},
               Routine _source = nullptr);
    virtual ~TensorNode() {}
    DEFINE_GETTYPE(TensorNode);

    bool operator==(const string &rhs) { return name == rhs; }
    friend bool operator==(const string &lhs, const TensorNode &rhs) {
        return lhs == rhs.name;
    }

    HashType hash() const override { return genhash(name); }
    string toReadable() const override;
    string toOutputShape() const;
    const std::string &getName() const { return name; }
    std::vector<int> &getPadding() { return paddings; }
    int getPadding(int i) const { return paddings[i]; }
    const vector<int> &getPaddings() const { return paddings; }
    void setPadding(int i, int p) { paddings[i] = p; }
    const vector<int> &getShape() const { return shape; }
    int getShape(int i) const { return shape[i]; }
    int64_t getSize() const;
    int getDims() const { return shape.size(); }
    const Routine &getSource() const { return source; }
    int getData(const Ref<vector<int>> &data, const vector<int> &idx);
    size_t getOffset(const vector<int> &idx);
};

enum class OpType { Range, Add, Mul, Div, Mod, Sub };
const char opSymbols[] = "#+*/%-";

class OperatorNode : public ExprNode {
protected:
    const OpType opType;
    VecExpr subExprs;

public:
    OperatorNode(OpType _opType) : opType(_opType){};
    OperatorNode(OpType _opType, VecExpr _subExprs)
        : opType(_opType), subExprs(_subExprs){};

    int getSubExprsNum() { return subExprs.size(); };
    const VecExpr &getSubExprs() { return subExprs; }
    const Expr &getSubExprs(int i) const { return subExprs[i]; }
    OpType getOpType() const { return opType; };
    void setOperands(int i, Expr e) { subExprs[i] = e; }
};

using Range = pair<int, int>;
using VarRangePair = pair<Var, Range>;
inline int getLength(const Range &range) { return range.second - range.first; }
struct IterationType {
    enum { Loop, Sum };
    constexpr static int NumIterationType = 2;
};
class RangeOpNode : public OperatorNode {
public:
    enum { Summand, END_POS };
    constexpr static int Loop = IterationType::Loop;
    constexpr static int Sum = IterationType::Sum;

private:
    vector<VarRangePair> vars[IterationType::NumIterationType];
    vector<int> paddings;

public:
    RangeOpNode(Expr _summand) : OperatorNode(OpType::Range, {_summand}){};
    RangeOpNode(const vector<VarRangePair> &_loopIters,
                const vector<VarRangePair> &_sumIters, Expr _summand,
                const vector<int> &paddings)
        : OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters},
          paddings(paddings){};
    DEFINE_GETTYPE(RangeOpNode);

    virtual HashType hash() const override {
        nnet_unimplemented_halt();
        return 0;
    };
    string toReadable() const override;
    const Expr &getSummand() const { return subExprs[Summand]; }
    const vector<VarRangePair> &getVarRanges(int _index) const {
        return vars[_index];
    }
    const vector<VarRangePair> &getLoopVarRanges() const {
        return vars[IterationType::Loop];
    }
    const vector<VarRangePair> &getSumVarRanges() const {
        return vars[IterationType::Sum];
    }
    int getNumOutputDims() const;
    bool hasVar(int index, Var name) const;
    bool hasLoopVar(Var name) const { return hasVar(Loop, name); }
    bool hasSumVar(Var name) const { return hasVar(Sum, name); }
    bool hasLoopVar(string name) const {
        return hasVar(Loop, make_ref<VarNode>(name));
    }
    bool hasSumVar(string name) const {
        return hasVar(Sum, make_ref<VarNode>(name));
    }
    int getVarIndex(int type, string name);
    void setSummand(Expr e) { subExprs[Summand] = e; }
    void setLoopIterator(const vector<VarRangePair> &vecExpr) {
        vars[Loop] = vecExpr;
    }
    void setSumIterator(const vector<VarRangePair> &vecExpr) {
        vars[Sum] = vecExpr;
    }
    void setIterator(const vector<VarRangePair> &loop,
                     const vector<VarRangePair> &sum) {
        setLoopIterator(loop);
        setSumIterator(sum);
    }

    const VarRangePair &getVarRange(int _index, int i) const {
        return vars[_index][i];
    }
    const Var &getLoopVar(int i) const { return vars[Loop][i].first; }
    Range getRange(const Var &var) const;
    VarRangePair getVarRange(const Var &var) const;
    bool hasPaddings() const;
    int getPaddings(int dim) const;
    vector<int> getPaddings() const;
    void setPaddings(vector<int> _paddings);
    void setVarRange(int _index, int i, VarRangePair pair) {
        vars[_index][i] = pair;
    }
    int64_t getFlops() const;
    int64_t getInputSize(const RangeOp &self) const;
    int64_t getOutputSize() const;
    vector<int> getOutputShape() const;
    // Including paddings
    vector<Range> getOutputRanges() const;
};

class BinaryOpNode : public OperatorNode {
    enum { LHS, RHS, END_POS };

public:
    BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs)
        : OperatorNode(_opType, {_lhs, _rhs}){};
    virtual ~BinaryOpNode() {}
    DEFINE_GETTYPE(BinaryOpNode);

    virtual HashType hash() const override {
        return genhash((HashType)opType,
                       genhash(subExprs[LHS]->hash(), subExprs[RHS]->hash()));
    };
    virtual string toReadable() const override;
    const Expr &getLhs() const { return getSubExprs(LHS); };
    const Expr &getRhs() const { return getSubExprs(RHS); };
    void setLhs(Expr e) { setOperands(LHS, e); };
    void setRhs(Expr e) { setOperands(RHS, e); };
    // If Var/constant, use this one
    optional<pair<Var, int>> getModDivParameter() const;
    // If (Var+constant)/constant, use this one
    pair<Expr, int> getModDivExpr() const;
    bool isSwapable() const;
};

class ConstantNode : public ExprNode {
    int val;

public:
    ConstantNode(int _val) : val(_val){};
    ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){};
    virtual ~ConstantNode() {}
    DEFINE_GETTYPE(ConstantNode);

    int getValue() const { return val; }
    virtual HashType hash() const override { return genhash(val, 6214587); };
    virtual string toReadable() const override {
        string ret;
        ret += std::to_string(val);
        return ret;
    };
};

class SubscriptNode : public ExprNode {
protected:
    Expr indexed;
    VecExpr subExprs;

public:
    SubscriptNode(Expr _indexed, vector<Expr> _subExprs) : subExprs(_subExprs) {
        setObject(_indexed);
    };
    DEFINE_GETTYPE(SubscriptNode);

    virtual HashType hash() const override {
        nnet_unimplemented_continue();
        return -1;
    };
    virtual string toReadable() const override;

    size_t getDims() const { return subExprs.size(); }
    const VecExpr &getIndex() const { return subExprs; }
    const Expr &getIndex(size_t i) const { return subExprs[i]; }
    void setIndex(size_t i, Expr e) { subExprs[i] = e; }
    Expr *getObjectPtr() { return &indexed; }
    Expr getObject() const { return indexed; }
    void setObject(Expr e);
    bool isRangeOpSubscripted() const;
    bool isTensorSubscripted() const { return !isRangeOpSubscripted(); }
    // Get the ranges of the object, including paddings
    vector<Range> getObjectRangesWithPaddings() const;
    vector<Range> getObjectRangesWithoutPaddings() const;
};

class FuncNode : public ExprNode {
protected:
    Subscript object;
    FuncType funcType;

public:
    FuncNode(Expr object, FuncType funcType) : funcType(funcType) {
        setObject(object);
    }
    DEFINE_GETTYPE(FuncNode);

    virtual HashType hash() const override {
        nnet_unimplemented_continue();
        return -1;
    };
    virtual string toReadable() const override;

    const Subscript &getObject() const { return object; }
    void setObject(Expr e);

    FuncType getFuncType() const { return funcType; }
};

// Wrappers for type deduction
Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts);
RangeOp makeRangeOperator(const vector<VarRangePair> &_loopIters,
                          const vector<VarRangePair> &_sumIters, Expr _summand,
                          const vector<int> &paddings = {});
Tensor makeTensor(const string &name, const vector<int> &shape,
                  const vector<int> &paddings = {},
                  const Routine &source = nullptr);
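// Usage sketch: building sum_{k in [0,4)} A[i, k] with the wrappers above
// (shapes are hypothetical):
//
//   auto i = make_ref<VarNode>("i"), k = make_ref<VarNode>("k");
//   auto A = makeTensor("A", {8, 4});
//   auto sub = makeSubscript(A, {i, k});
//   auto range = makeRangeOperator({{i, {0, 8}}}, {{k, {0, 4}}}, sub);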

// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
                          *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
    os << ((!a) ? string("nullptr") : a->toReadable());
    return os;
}

// Pretty output for dbg with Ref
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
                          *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
    os << ((!a) ? string("nullptr") : a->toReadable());
    return os;
}
#undef DEFINE_GETTYPE

} // namespace nnet

namespace std {
template <> struct hash<nnet::VarNode &> {
    size_t operator()(const nnet::VarNode &t) const {
        return std::hash<string>()(t.getName());
    }
};
} // namespace std
@ -0,0 +1,234 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>

namespace nnet {

using PatternTensorMap = vector<Tensor>;
using PatternIterRangeMap = PtrMap<Iterator, VarRangePair>;

enum class MismatchType {
    // Search required (undetermined)
    MoreVar,
    LessVar,
    StrideMismatch,
    // guided DLT (determined)
    DLMismatch,
    OutputDLMismatch,
    OutputDimismatch
};
struct Mismatch {
    MismatchType type;
    int bitmap; // Row ID of IT
    PtrMap<Iterator, Iterator>
        mappingIter_r; // For DLT mismatch, iters are mapped
    Mismatch(MismatchType _type, int _bitmap) : type(_type), bitmap(_bitmap) {}
    Mismatch(MismatchType _type, int _bitmap,
             PtrMap<Iterator, Iterator> _mappingIter_r)
        : type(_type), bitmap(_bitmap), mappingIter_r(_mappingIter_r) {}
};
class Pattern;
class IteratorTable {
protected:
    // using Appearance = map<string, vector<pair<Tensor, int>>>;
    // using StrideTable = map<TensorNode *, vector<tuple<string, int, int>>>;
    // // Var, dim, stride
    RangeOp rangeOp;
    // To real tensor
    // FIXME: redundant
    Appearance appearance;
    vector<Tensor> tensors;       // original tensor sequence
    vector<Subscript> subscripts; // original subscripts sequence
    StrideTable strideTable;      // TODO [Refactor]: rename strideTable
    PatternIterRangeMap iterToRange;

    // mapping
    vector<int> tensorMap; // [index for tensors] -> tensorID in pattern
    PtrMap<Iterator, Iterator> iterMap; // [expr iter] -> pattern iter

    // final data
    vector<vector<Iterator>> posTable; // [Tensor bitmap]=[Iterator]
    vector<vector<vector<Iterator>>>
        iterInTensorDim; // [tensorID][dimOfTensor]=[Iterator];
                         // strides in the same dim may be added together

    vector<vector<PtrMap<Iterator, int>>>
        strideInDim; // [tensorID][dimOfTensor][Iterator]=stride;
                     // strides in the same dim may be added together

    PtrMap<Iterator, vector<int>> strideInTensor; // [Iterator][tensorID]=stride

    // final data: auxiliary data
    vector<int> tensorIDMap_r;
    PatternTensorMap tensorMap_r;
    PatternIterRangeMap iterToRange_r;

public:
    virtual ~IteratorTable() {}
    IteratorTable() {}
    IteratorTable(const IteratorTable &) = delete;
    [[nodiscard]] bool analyzeExpr(const RangeOp &rangeOp);
    // mapTensors
    void buildTable(const vector<int> &_tensorMap);
    void buildTableWithDefaultMap();
    /**
     * @brief Check whether the expression matches a pattern. If not, return
     * the detailed reasons for guided search.
     *
     * @param patternIT
     * @return vector<Mismatch> mismatched IT rows/tensors for guided DLT.
     */
    vector<Mismatch> matchPatternIT(const Pattern &patternIT);
    void matchIterators();
    int getNumInputs() const { return tensors.size(); }
    int getNumTensors() const { return tensors.size() + 1; }
    int getNumRows() const { return 1 << getNumTensors(); }
    int getNumIterators() const { return strideTable.size(); }
    // vector<Tensor> tensorMap_r(
    //     pattern.nInputs); // [pattern tensor ID] -> real tensor
    // map<string, VarRangePair> iterToRange_r; // [pattern iter] -> iter &
    // range
    auto getTables() const {
        return tuple(posTable, iterInTensorDim, strideInTensor);
    }
    const auto &getStrideInDim() const { return strideInDim; }
    vector<vector<Iterator>> getIterInTensorDim(int tensorID) const {
        return iterInTensorDim[tensorID];
    }
    const vector<Iterator> &getPosTable(int bitmap) const {
        return posTable[bitmap];
    }
    pair<PatternTensorMap, PatternIterRangeMap> getReverseMap() const;

    int getStridesInTensor(Iterator iter, int tensorID) const;
    vector<int> getIterDimInTensor(int tensorID, const Iterator &iter) const;
    Tensor getTensor(int tensorID) const { return tensorMap_r[tensorID]; }
    Subscript getSubscript(int tensorID) const {
        return subscripts[tensorIDMap_r[tensorID]];
    }
    Range getIterRange(const Iterator &iter) const {
        return rangeOp->getRange(iter);
    }

    /**
     * @brief Check the strides of each iterator and their positions in
     * tensors. Since many-to-many iterator matchings exist, we make this
     * procedure a separate function to deal with different iterator mapping
     * solutions.
     *
     * @param patternIT
     * @param mappingIter_r
     * @return vector<Mismatch>
     */
    vector<Mismatch>
    matchPatternITCheckStrides(const Pattern &patternIT,
                               PtrMap<Iterator, Iterator> mappingIter_r);
    RangeOp getRangeOp() const;
};
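// Usage sketch of the matching flow above (`rangeOp` is a hypothetical
// RangeOp; only methods declared in this header are used):
//
//   IteratorTable it;
//   if (it.analyzeExpr(rangeOp)) {
//       it.buildTableWithDefaultMap();
//       auto mismatches = it.matchPatternIT(MatmulPattern::getMatmulPattern());
//       if (mismatches.empty()) { /* rangeOp matches the Matmul pattern */ }
//   }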

struct StrideConstraint {
    int tensorID;
    Var v0, v1;
    enum class Constraint { SAME, PROPOTIONAL } type;
};

class Pattern : public IteratorTable {
    vector<StrideConstraint> strideConstraints;

public:
    virtual Expr
    buildExpr(const Expr &expr, const vector<Tensor> &tensors,
              [[maybe_unused]] const PatternIterRangeMap &varRanges,
              string outputName,
              [[maybe_unused]] const IteratorTable &exprIT) const = 0;
    /**
     * @brief Check whether every index of this tensor is a single iterator.
     *
     * @param tensorID
     */
    bool isAllUniqueAccess(int tensorID) const;
    const auto &getStrideConstraints() const { return strideConstraints; };
    int calcPadding(const Tensor &tensor, int dim, Range rangeH, Range rangeR,
                    int offset) const;
};

class MatmulPattern : public Pattern {
public:
    static const Pattern &getMatmulPattern();
    static pair<Expr, pair<Tensor, Tensor>> getExpr(bool transA, bool transB,
                                                    int b, int m, int n, int k);

    Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
                   [[maybe_unused]] const PatternIterRangeMap &varRanges,
                   string outputName,
                   [[maybe_unused]] const IteratorTable &exprIT) const override;
};

class ConvPattern : public Pattern {
private:
    static const Var n, c, h, w, f, r, s;

public:
    static const Pattern &getPattern();
    static Expr getExpr(Tensor A, Tensor K, int n, int c, int h, int w, int f,
                        int r, int s);

    Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
                   [[maybe_unused]] const PatternIterRangeMap &varRanges,
                   string outputName,
                   [[maybe_unused]] const IteratorTable &exprIT) const override;
};

class ConvTransPattern : public Pattern {
private:
    static const Var n, c, h, w, f, r, s;

public:
    static const Pattern &getPattern() = delete;
    static Expr getExpr(Tensor A, Tensor K, int N, int C, int H, int W, int F,
                        int R, int S);

    Expr
    buildExpr(const Expr &expr, const vector<Tensor> &tensors,
              [[maybe_unused]] const PatternIterRangeMap &varRanges,
              string outputName,
              [[maybe_unused]] const IteratorTable &exprIT) const override {
        nnet_unimplemented_halt();
        return nullptr;
    };
};

class Sg2bmmPattern : public Pattern {
private:
    static const Var b, m, w, k;

public:
    static const Pattern &getPattern();
    static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int K,
                                                    int W, int D);

    Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
                   [[maybe_unused]] const PatternIterRangeMap &varRanges,
                   string outputName,
                   [[maybe_unused]] const IteratorTable &exprIT) const override;
};

class LongformerGBMMPattern : public Pattern {
private:
    static const Var b, m, w, n;

public:
    static const Pattern &getPattern();
    static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int W,
                                                    int K, int dilation);

    Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
                   [[maybe_unused]] const PatternIterRangeMap &varRanges,
                   string outputName,
                   [[maybe_unused]] const IteratorTable &exprIT) const override;
};

const Pattern &getPattern(RoutineType targetOp);
string getPatternName(RoutineType targetOp);

} // namespace nnet
@ -0,0 +1,57 @@
#pragma once
#include "core/mutator.h"
#include "nnet/expr.h"

#ifdef ABC

namespace infini {

class NMutator : public Mutator {
private:
    // Suffix -N: NNet objects.
    // Suffix -T: tpm objects.
    // Map: NNet tensors -> tpm tensor.
    std::map<std::string, Tensor> inputsNameNToTensorT;
    enum class Mode { Normal, ToNaiveMembound, RuleBased } mode = Mode::Normal;
    const double bandwidth = double(200) * 1024 * 1024 * 1024;
    // If in RuleBased mode, use derivationRules in derivator
    const std::vector<int> derivationRules;

public:
    NMutator();
    NMutator(const std::vector<int> &derivationRules);
    ~NMutator();

    vector<Graph> run(const Graph &in_graph) override;
    void setToNaiveMembound();

    void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
    long long cntStates = 0;
    long long cntCandidates = 0;

private:
    int maxDepth = 8;
    nnet::Expr opToExpression(Operator op);
    void runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs);

    /**
     * @brief Test helper. Convert a single Op to a Membound Op for
     * correctness checks.
     */
    void runSingleOpToNaiveMembound(Graph in_graph,
                                    std::vector<Graph> &out_graphs);
    void runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs);
    Graph expressionToGraph(nnet::Expr expr, Graph in_graph);
    Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
    double memboundTime(ssize_t cnt);
    double memboundTime(const Shape &dims);

    Graph transformTConv1x1(Operator op);
    Graph transformTConv3x3(Operator op);
    Graph transformDialtedConv(Operator op);
    Graph transformConv1x1(Operator op);
    Graph transformConv1xk(Operator op);
};

} // namespace infini
#endif
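// Usage sketch (note the class is currently compiled out by the #ifdef above;
// `graph` is a hypothetical infini::Graph):
//
//   NMutator mutator({3, 2, 2, 5, 8, 8, 6, 90}); // rule-based mode
//   auto candidates = mutator.run(graph);        // mutated graph candidates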
@ -0,0 +1,38 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <cassert>
#include <iostream>

namespace nnet {

class PermutationGenerator {
    vector<vector<Iterator>> from, to;
    vector<vector<size_t>> mapping;

public:
    PermutationGenerator(vector<vector<Iterator>> _from,
                         vector<vector<Iterator>> _to);
    bool next();
    PtrMap<Iterator, Iterator> get() const;
};

template <typename T> class SubsetGenerator {
    vector<T> elements;
    int n, bitmap;

public:
    SubsetGenerator(vector<T> elements, bool nonEmpty = true)
        : elements(elements), n(elements.size()), bitmap(nonEmpty ? 1 : 0) {
        assert(n < 10);
    };
    bool next() { return ((++bitmap) < (1 << n) - 1); }
    vector<T> get() const {
        vector<T> ret;
        for (int i = 0; i < n; ++i)
            if (bitmap & (1 << i))
                ret.emplace_back(elements[i]);
        return ret;
    }
};

} // namespace nnet
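// Usage sketch: with nonEmpty=true, SubsetGenerator enumerates the non-empty
// proper subsets of its elements (bitmaps 1 .. 2^n-2):
//
//   SubsetGenerator<int> gen({10, 20, 30});
//   do {
//       auto subset = gen.get(); // {10}, {20}, {10,20}, {30}, {10,30}, {20,30}
//   } while (gen.next());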
@ -0,0 +1,200 @@
#pragma once
#include "common.h"
#include <functional> // hash
#include <memory>
#include <type_traits>

namespace nnet {

template <typename T> struct is_ref;

/**
 * Ref-counting pointer
 *
 * This class is thread-safe (For developers: concurrent accesses through
 * different `std::shared_ptr`s to the same object is already thread-safe, while
 * modifying the same `std::shared_ptr` is not. We never modify a `Ref`, so no
 * locks are needed. See https://en.cppreference.com/w/cpp/memory/shared_ptr)
 */
template <class T> class Ref {
    static_assert(is_ref<T>::value == false, "Ref should not be nested");

    template <class U> friend class Ref;

    std::shared_ptr<T> ptr_;

public:
    typedef T Object;

    Ref() = default;
    // Ref(std::nullptr_t) : Ref() {}
    constexpr Ref(nullptr_t) noexcept : Ref() {}
    Ref(const Ref &) = default;
    Ref(Ref &&) = default;
    Ref(std::shared_ptr<T> &&ptr) : ptr_(std::move(ptr)) {}
    // Ref(const std::shared_ptr<T> &ptr) : ptr_(ptr) {}

    // /// DO NOT USE THIS CONSTRUCTOR IN PUBLIC
    // /// It is public because Pybind11 needs it
    // Ref(T *ptr) : ptr_(ptr) {}

    /**
     * Shared with any compatible references
     */
    template <class U,
              typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
    Ref(const Ref<U> &other) : ptr_(std::static_pointer_cast<T>(other.ptr_)) {}

    template <class U,
              typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
    Ref &operator=(const Ref<U> &other) {
        ptr_ = std::static_pointer_cast<T>(other.ptr_);
        return *this;
    }

    Ref &operator=(const Ref &) = default;
    Ref &operator=(Ref &&) = default;

    template <class U> Ref<U> as() const {
        Ref<U> ret;
        ret.ptr_ = std::dynamic_pointer_cast<U>(ptr_);
        return ret;
    }

    bool isValid() const { return ptr_ != nullptr; }

    T &operator*() const {
        nnet_assert(isValid(), "Empty pointer.");
        return *ptr_;
    }

    T *operator->() const {
        nnet_assert(isValid(), "Empty pointer.");
        return ptr_.get();
    }

    T *get() const {
        nnet_assert(isValid(), "Empty pointer.");
        return ptr_.get();
    }

    friend inline bool operator==(const Ref &lhs, nullptr_t) {
        return !lhs.isValid();
    }
    friend inline bool operator!=(const Ref &lhs, nullptr_t) {
        return !(lhs == nullptr);
    }
    explicit operator bool() const { return ptr_ != nullptr; }
    bool operator!() { return ptr_ == nullptr; }

    void swap(Ref &__b) noexcept { ptr_.swap(__b.ptr_); }
};

template <class T, class U,
          typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
Ref<T> as(const Ref<U> &ref) {
    return ref.template as<T>();
}

template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
    return Ref(make_shared<T>(std::forward<Params>(params)...));
}

// Comparators for Ref
template <typename T> struct is_ref : std::false_type {};
template <typename T> struct is_ref<Ref<T>> : std::true_type {};

template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<not is_ref<std::tuple_element_t<index, Tuple>>::value,
                          bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
    if constexpr (index >=
                  std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
        return std::get<index>(lhs) < std::get<index>(rhs);
    else {
        if (std::get<index>(lhs) != std::get<index>(rhs))
            return std::get<index>(lhs) < std::get<index>(rhs);
        else
            return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
    }
}

template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<is_ref<std::tuple_element_t<index, Tuple>>::value and
                              not address_based,
                          bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
    if constexpr (index >=
                  std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
        return std::get<index>(lhs)->less(std::get<index>(rhs));
    else {
        if (std::get<index>(lhs)->neq(std::get<index>(rhs)))
            return std::get<index>(lhs)->less(std::get<index>(rhs));
        else
            return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
    }
}

template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<
    is_ref<std::tuple_element_t<index, Tuple>>::value and address_based, bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
    if constexpr (index >=
                  std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
        return std::get<index>(lhs).get() < std::get<index>(rhs).get();
    else {
        if (std::get<index>(lhs).get() != std::get<index>(rhs).get())
            return std::get<index>(lhs).get() < std::get<index>(rhs).get();
        else
            return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
    }
}

template <class Tuple> bool ref_addr_less(const Tuple &lhs, const Tuple &rhs) {
    return __ref_less<Tuple, 0, true>(lhs, rhs);
}

template <class Tuple> bool ref_value_less(const Tuple &lhs, const Tuple &rhs) {
    return __ref_less<Tuple, 0, false>(lhs, rhs);
}

template <class Tuple> class RefAddrLess {
public:
    bool operator()(const Tuple &a, const Tuple &b) const {
        return ref_addr_less(a, b);
    }
};

template <class Tuple> class RefValueLess {
public:
    bool operator()(const Tuple &a, const Tuple &b) const {
        return ref_value_less(a, b);
    }
};

// make_ref_from_tuple
template <typename _Tp, typename _Tuple, size_t... _Idx>
constexpr Ref<_Tp> make_ref_from_tuple_impl(_Tuple &&__t,
                                            std::index_sequence<_Idx...>) {
    return make_ref<_Tp>(std::get<_Idx>(std::forward<_Tuple>(__t))...);
}

template <typename _Tp, typename _Tuple>
constexpr Ref<_Tp> make_ref_from_tuple(_Tuple &&__t) {
    return make_ref_from_tuple_impl<_Tp>(
        std::forward<_Tuple>(__t),
        std::make_index_sequence<std::tuple_size_v<std::decay_t<_Tuple>>>{});
}

} // namespace nnet

// namespace std {

// template <class T> struct hash<ir::Ref<T>> {
//     hash<T *> hash_;
//     size_t operator()(const ir::Ref<T> &ref) const { return hash_(ref.get());
//     }
// };

// } // namespace std
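// Usage sketch for Ref (VarNode/ExprNode are declared in expr.h):
//
//   Ref<ExprNode> e = make_ref<VarNode>("i"); // shares via the base reference
//   if (auto v = as<VarNode>(e))              // checked downcast
//       /* v->getName() == "i" */;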
@ -0,0 +1,158 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
#include <sstream>
namespace nnet {

class RoutineNode;
class MatmulNode;
class ElementWiseNode;
using Routine = Ref<RoutineNode>;
using Matmul = Ref<MatmulNode>;
using ElementWise = Ref<ElementWiseNode>;

#define DEFINE_GETTYPE(CLASS)                                                  \
    RoutineType getType() const override { return RoutineType::CLASS##Type; }

class RoutineNode {
protected:
    Expr expr;
    vector<Tensor> inputs;

public:
    RoutineNode(Expr _expr, const vector<Tensor> &_inputs);
    virtual string toReadable() const = 0;
    const Expr &getExpr() const { return expr; }
    const vector<Tensor> &getInputs() const { return inputs; }
    virtual RoutineType getType() const = 0;
};

using MatmulArgs = tuple<int,   // b
                         int,   // m
                         int,   // n
                         int,   // k
                         bool,  // transa
                         bool>; // transb

class MatmulNode : public RoutineNode {
    int b, m, n, k;
    bool transa, transb;

public:
    MatmulNode(Expr _source, Tensor A, Tensor B, int _b, int _m, int _n, int _k,
               bool _transa, bool _transb)
        : RoutineNode(_source, {A, B}), b(_b), m(_m), n(_n), k(_k),
          transa(_transa), transb(_transb) {}
    DEFINE_GETTYPE(MatmulNode);

    string toReadable() const override;

    friend bool operator==(const MatmulNode &lhs, const MatmulNode &rhs);
    MatmulArgs getArgs() { return tuple(b, m, n, k, transa, transb); }
};

using ConvArgs = tuple<int,  // ph
                       int,  // pw
                       int,  // sh
                       int,  // sw
                       int,  // dh
                       int>; // dw

class ConvNode : public RoutineNode {
    int ph, pw;
    int sh, sw;
    int dh, dw;

public:
    ConvNode(Expr _source, Tensor A, Tensor K, int _ph, int _pw, int _sh = 1,
             int _sw = 1, int _dh = 1, int _dw = 1)
        : RoutineNode(_source, {A, K}), ph(_ph), pw(_pw), sh(_sh), sw(_sw),
          dh(_dh), dw(_dw) {}
    DEFINE_GETTYPE(ConvNode);

    string toReadable() const override;
    vector<int> getShape() const;
    friend bool operator==(const ConvNode &lhs, const ConvNode &rhs);
    ConvArgs getArgs() const;
};

class ElementWiseNode : public RoutineNode {
    vector<int> outputShape;

public:
    // _outputShape is redundant, but expr is still missing for DLT.
    ElementWiseNode(Expr _source, vector<Tensor> _inputs,
                    vector<int> _outputShape)
        : RoutineNode(_source, _inputs), outputShape(_outputShape) {}
    DEFINE_GETTYPE(ElementWiseNode);

    string toReadable() const override;
    /**
     * @brief Get the estimated time of this mem-bound OP.
     *
     * @return double Time in ms.
     */
    double getEstimatedTime() const;
    const vector<int> &getOutputShape() const { return outputShape; }
};

using G2bmmArgs = tuple<int,  // b
                        int,  // m
                        int,  // w
                        int,  // k
                        int>; // dilation
class G2bmmNode : public RoutineNode {
    int b, m, w, k;

public:
    G2bmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int k,
              int d = 1)
        : RoutineNode(source, {A, B}), b(b), m(m), w(w), k(k) {
        assert(d == 1);
    }
    DEFINE_GETTYPE(G2bmmNode);

    vector<int> getShape() const;
    string toReadable() const override;
    G2bmmArgs getArgs() const;
};

using GbmmArgs = tuple<int,  // b
                       int,  // m
                       int,  // w
                       int,  // n
                       int>; // dilation
class GbmmNode : public RoutineNode {
    int b, m, w, n;

public:
    GbmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int n,
             int d = 1)
        : RoutineNode(source, {A, B}), b(b), m(m), w(w), n(n) {
        assert(d == 1);
    }
    DEFINE_GETTYPE(GbmmNode);

    vector<int> getShape() const;
    string toReadable() const override;
    GbmmArgs getArgs() const;
};

// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<
                          std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
    os << ((!a) ? string("Null shared_ptr") : a->toReadable());
    return os;
}

// Pretty output for dbg with Ref
template <typename T, typename std::enable_if_t<
                          std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
    os << ((!a) ? string("Null shared_ptr") : a->toReadable());
    return os;
}

} // namespace nnet
@ -0,0 +1,28 @@
|
|||
#pragma once
#include "common.h"
#include "derivator.h"

// clang-format off
#define CAT(A, B) A##B
#define SELECT(NAME, NUM) CAT(NAME##_, NUM)
#define GET_COUNT( _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ... ) COUNT
#define VA_SIZE( ... ) GET_COUNT( __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 )
#define VA_SELECT( NAME, ... ) SELECT( NAME, VA_SIZE(__VA_ARGS__) )(__VA_ARGS__)

#define _DEFVAR_1(name) auto name = make_ref<VarNode>(#name);
#define _DEFVAR_2(name, ...) _DEFVAR_1(name); _DEFVAR_1(__VA_ARGS__)
#define _DEFVAR_3(name, ...) _DEFVAR_1(name); _DEFVAR_2(__VA_ARGS__)
#define _DEFVAR_4(name, ...) _DEFVAR_1(name); _DEFVAR_3(__VA_ARGS__)
#define _DEFVAR_5(name, ...) _DEFVAR_1(name); _DEFVAR_4(__VA_ARGS__)
#define _DEFVAR_6(name, ...) _DEFVAR_1(name); _DEFVAR_5(__VA_ARGS__)
#define _DEFVAR_7(name, ...) _DEFVAR_1(name); _DEFVAR_6(__VA_ARGS__)
#define _DEFVAR_8(name, ...) _DEFVAR_1(name); _DEFVAR_7(__VA_ARGS__)
#define _DEFVAR_9(name, ...) _DEFVAR_1(name); _DEFVAR_8(__VA_ARGS__)
#define DEFINE_VAR(...) VA_SELECT(_DEFVAR, __VA_ARGS__)
// clang-format on

namespace nnet {
int matchExprResult(Derivator &derivator, string fn);
bool checkExprLogSame(string fnPrefix, int start, int end);
bool checkExprsEquvivalence(VecExpr exprs);
} // namespace nnet
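
// Illustrative expansion (added note, not part of the original commit):
// DEFINE_VAR counts its arguments with VA_SIZE and dispatches to the matching
// _DEFVAR_N overload, so
//     DEFINE_VAR(i, j, k);
// expands to
//     auto i = make_ref<VarNode>("i");
//     auto j = make_ref<VarNode>("j");
//     auto k = make_ref<VarNode>("k");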
@ -0,0 +1,128 @@
#pragma once
#include "common.h"
#include "derivator.h"
#include "expr.h"
#include "routine.h"
#include <iostream>
#include <unordered_map>

namespace nnet {

template <typename FType> class Functor;

template <typename R, typename... Args> class Functor<R(Args...)> {
  protected:
    int verbose;

    // FIXME: scope should be protected
  public:
    Functor(int _verbose = 0) : verbose(_verbose) {}
    virtual ~Functor() = default;
#define DISPATCH(CLASS)                                                        \
    case NodeType::CLASS##Type:                                                \
        return this->visit_(as<CLASS>(c), std::forward<Args>(args)...);       \
        break

#define FUNCTOR_DEFAULT                                                        \
    { return visitDefault(c, std::forward<Args>(args)...); }

    virtual R dispatch(const Expr &c, Args... args) {
        switch (c->getType()) {
            DISPATCH(ConstantNode);
            DISPATCH(BinaryOpNode);
            DISPATCH(RangeOpNode);
            DISPATCH(SubscriptNode);
            DISPATCH(TensorNode);
            DISPATCH(VarNode);
            DISPATCH(FuncNode);
        default:
            nnet_assert(0, "Unknown type");
            return R();
        }
    }

    virtual R visit_(const Constant &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visit_(const BinaryOp &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visit_(const RangeOp &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visit_(const Subscript &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visit_(const Var &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT;
    virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) {
        dbg(*c);
        nnet_assert(0, "Reached an unimplemented visit function.");
        return R();
    };

    [[deprecated("Define explicit methods for public access.")]] R
    operator()(const Expr &e, Args... args) {
        return dispatch(e, std::forward<Args>(args)...);
    }
#undef FUNCTOR_DEFAULT
#undef DISPATCH
};

class Mutator : public Functor<Expr()> {
  public:
    Mutator(int _verbose = 0) : Functor(_verbose) {}
    Expr visit_(const Constant &c) override;
    Expr visit_(const BinaryOp &c) override;
    Expr visit_(const RangeOp &c) override;
    Expr visit_(const Subscript &c) override;
    Expr visit_(const Var &c) override;
    Expr visit_(const Tensor &c) override;
    Expr visit_(const Func &c) override;
};

// template <typename... Args>
// class SingleStageVisitor : public Functor<void, Args...> {
//   public:
//     SingleStageVisitor(int _verbose = 0) : Functor<R, Args...>(_verbose) {}
//     // R visit(const Constant &c) override ;
//     R visit_(const BinaryOp &c) override {
//         if (verbose)
//             dbg(*c);
//         this->dispatch(c->getLhs());
//         this->dispatch(c->getRhs());
//     }
//     R visit_(const RangeOp &c) override {
//         if (verbose)
//             dbg(*c);
//         this->dispatch(ret->getSummand());
//         // NOT visit iterators and its ranges
//     }
//     R visit_(const Subscript &c) override {
//         if (verbose)
//             dbg(*c);
//         this->dispatch(ret->getObject());
//         for (size_t i = 0; i < ret->getDims(); ++i)
//             this->dispatch(ret->getIndex(i));
//     }
//     // R visit(const Var &c) override;
//     // R visit(const Tensor &c) override;
// };

// } // namespace nnet
// #include "nnet/Visitor/ReplaceVariable.h"
// #include "nnet/Visitor/StrideVisitor.h"
// namespace nnet {

class ExprTreeVisitor : public Functor<void(void)> {
  private:
    bool inBinary, inRange, inSub, inTensor;

  public:
    ExprTreeVisitor(bool _inBinary = true, bool _inRange = true,
                    bool _inSub = true, bool _inTensor = true,
                    int _verbose = 0)
        : Functor(_verbose), inBinary(_inBinary), inRange(_inRange),
          inSub(_inSub), inTensor(_inTensor) {}
    void visit_(const Constant &c) override;
    void visit_(const BinaryOp &c) override;
    void visit_(const RangeOp &c) override;
    void visit_(const Subscript &c) override;
    void visit_(const Var &c) override;
    void visit_(const Tensor &c) override;
    void visit_(const Func &c) override;
};

} // namespace nnet
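
// Illustrative sketch (added note, not part of the original commit; untested):
// dispatch() routes a node to the visit_ overload matching its runtime
// NodeType, so a new analysis only overrides the cases it cares about, e.g.
//
//     class TensorCounter : public ExprTreeVisitor {
//       public:
//         int count = 0;
//         void visit_(const Tensor &c) override { ++count; }
//     };
//     // TensorCounter v; v.dispatch(root); then v.count holds the total,
//     // assuming the remaining ExprTreeVisitor overloads recurse as usual.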
@ -0,0 +1,25 @@
#include "nnet/Pass/MatchComputationKernel.h"
|
||||
#include "nnet/Visitor/PatternMatcher.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
// RE: is this duplicate with Rule6KenerlMatching?
|
||||
void MatchComputationKernel::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
nnet_assert(derivator.getSearchState() == 2, __LINE__);
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
// Build wrapper stages for enforce axis starts from 0
|
||||
PatternMatcher patternMatcher(derivator, cur);
|
||||
cur = patternMatcher.getOffsetCur();
|
||||
|
||||
auto matches = patternMatcher.matchWithPattern(
|
||||
cur, getPattern(derivator.getTargetOp()));
|
||||
matches = patternMatcher.applyWrapper(matches);
|
||||
|
||||
for (auto newCur : matches) {
|
||||
derivator.setSearchState(3);
|
||||
nextStep(origin, depth, rCur, newCur);
|
||||
derivator.setSearchState(2);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,23 @@
|
|||
#include "nnet/Pass/MatchMemBoundKernel.h"
|
||||
#include "nnet/Visitor/InputVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void MatchMemBoundKernel::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
// FIXME: Whether the Formula is a Membound OP should be checked.
|
||||
nnet_assert(derivator.getSearchState() == 3, __LINE__);
|
||||
nnet_assert(origin.root.get() == rCur.get(),
|
||||
"Only match the entire formula as a Membound Op");
|
||||
auto rangeOp = as<RangeOpNode>(origin.root);
|
||||
const auto &inputs = InputVisitor().getInputs(rangeOp);
|
||||
auto source =
|
||||
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
|
||||
auto tensor =
|
||||
makeTensor(newTensorName(), rangeOp->getOutputShape(), {}, source);
|
||||
// The original code directly appends candidate. But it seems should be done
|
||||
// by the search.
|
||||
// appendCanddiate(as<TensorNode>(tensor), depth);
|
||||
nextStep(origin, depth, rCur, tensor);
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,58 @@
|
|||
#include "nnet/Pass/Pass.h"
|
||||
#include "nnet/Visitor/CloneMutator.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Pass::Pass(Derivator &derivator, const string &passName)
|
||||
: derivator(derivator), passName(passName),
|
||||
enableLogging(derivator.getLogMode() != Derivator::LogMode::NoLog),
|
||||
enableDebug(false) {}
|
||||
|
||||
Pass::~Pass() = default;
|
||||
|
||||
void Pass::setEnableLogging(bool value) { enableLogging = value; }
|
||||
|
||||
void Pass::setEnableDebug(bool value) { enableDebug = value; }
|
||||
|
||||
void Pass::run(Formula &origin, int dfsDepth, Expr &rCur) {
|
||||
initialize(origin, rCur);
|
||||
transform(origin, dfsDepth, rCur);
|
||||
finalize();
|
||||
}
|
||||
|
||||
void Pass::initialize(Formula &origin, const Expr &rCur) {}
|
||||
|
||||
void Pass::finalize() {}
|
||||
|
||||
Var Pass::getNewVar() { return derivator.getNewVar(); }
|
||||
|
||||
string Pass::newTensorName() { return derivator.newTensorName(); }
|
||||
|
||||
void Pass::nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
|
||||
const string &ruleMsg) {
|
||||
// push rule action description
|
||||
if (enableLogging) {
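        // Added note (not in the original commit): rCur aliases a node inside
        // origin.root, so the two swaps below temporarily splice newCur into
        // the formula, snapshot the whole tree for the log, and then restore
        // the original expression.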
        rCur.swap(newCur);
        derivator.pushIntermediateState(origin.root);
        rCur.swap(newCur);
        derivator.pushRuleState(passName);
        derivator.pushRuleMsg(ruleMsg);
    }

    if (enableDebug) {
        // In debug mode, do not recurse; save the transformed state instead
        transformations.emplace_back(CloneMutator().clone(newCur));
    } else
        derivator.nextStep(origin, depth, rCur, newCur);

    // pop rule action description
    if (enableLogging) {
        derivator.popIntermediateState();
        derivator.popRuleState();
        derivator.popRuleMsg();
    }
}

const VecExpr &Pass::getTransformations() { return transformations; }

} // namespace nnet
@ -0,0 +1,134 @@
#include "nnet/Pass/Rule1VariableSplit.h"
|
||||
#include "nnet/Visitor/ReplaceVariable.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule1VariableSplit::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
vector<Replace> replaces = getSplitableVar(cur);
|
||||
// for (const auto &replace : replaces)
|
||||
// dbg(replace.oldIters, replace.newIters, replace.psis,
|
||||
// replace.newVarRanges);
|
||||
for (const auto &replace : replaces) {
|
||||
auto replacedSummand = replaceIters(cur->getSummand(), replace);
|
||||
if (!replacedSummand) {
|
||||
// TODO: if a real getMergableExprs is implemented, this case should
|
||||
// be an error. Since the expr should appear in the AST.
|
||||
dbg("Warning: No replacment happens.");
|
||||
continue;
|
||||
}
|
||||
auto inner =
|
||||
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
|
||||
// build the outerRange{innerRange}[indexForInner] to do DLT
|
||||
Expr nextCur = nullptr;
|
||||
if (replace.iteratorType == IterationType::Loop) {
|
||||
auto subscriptedInner =
|
||||
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
|
||||
nextCur = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
|
||||
} else
|
||||
nextCur = inner;
|
||||
|
||||
string msg = "====== END rule1 VariableSplit: ";
|
||||
dbg(msg, replace.oldIters, replace.newIters, replace.phis,
|
||||
replace.psis);
|
||||
msg = replace.toReadable();
|
||||
nextStep(origin, depth, rCur, nextCur, msg);
|
||||
}
|
||||
}
|
||||
|
||||
vector<Replace> Rule1VariableSplit::getSplitableVar(const RangeOp &rangeOp) {
|
||||
vector<Replace> ret;
|
||||
// Split strategy
|
||||
vector<int> SumFactors, LoopFactors;
|
||||
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
|
||||
SumFactors = {3};
|
||||
LoopFactors = {4};
|
||||
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
|
||||
SumFactors = {2, 3};
|
||||
// LoopFactors = {3, 4};
|
||||
LoopFactors = {4};
|
||||
} else
|
||||
nnet_unimplemented_halt();
|
||||
|
||||
// Split Sum variable
|
||||
for (const int k : SumFactors) {
|
||||
for (const auto &[var, range] : rangeOp->getSumVarRanges()) {
|
||||
int len = range.second - range.first;
|
||||
auto p1 = getNewVar(); // p1=i/k
|
||||
auto p2 = getNewVar(); // p2=i%k
|
||||
if (len > 10 || len <= k || len % k != 0)
|
||||
continue;
|
||||
|
||||
Range range1, range2;
|
||||
if (range.first < 0) {
|
||||
nnet_unimplemented_halt();
|
||||
// FIXME: this must be ERROR
|
||||
range1.first = range.first / k;
|
||||
range1.second = range1.first + len / k;
|
||||
range2.first = -k / 2;
|
||||
range2.second = range2.first + k;
|
||||
} else if (range.first == 0) {
|
||||
range1.first = 0;
|
||||
range1.second = len / k;
|
||||
range2.first = 0;
|
||||
range2.second = k;
|
||||
} else {
|
||||
nnet_unimplemented_continue();
|
||||
continue;
|
||||
}
|
||||
Replace replace{.iteratorType = IterationType::Sum,
|
||||
.oldIters = {var},
|
||||
.newIters = {p1, p2},
|
||||
.phis = {},
|
||||
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
|
||||
.newVarRanges = {{p1, range1}, {p2, range2}}};
|
||||
ret.emplace_back(replace);
|
||||
}
|
||||
}
|
||||
for (const int k : LoopFactors) {
|
||||
// Split Loop variable
|
||||
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
|
||||
const int len = range.second - range.first;
|
||||
// Debug HACK for dilated SG2BMM
|
||||
if (derivator.getPassMode() == Derivator::PassMode::Debug &&
|
||||
!(var->getName() == "m" && len % k == 0))
|
||||
continue;
|
||||
|
||||
// Illeagel conditions
|
||||
if (range.second - range.first <= k ||
|
||||
(range.second - range.first) % k != 0)
|
||||
continue;
|
||||
// Unsupport conditions
|
||||
if (range.first != 0)
|
||||
continue;
|
||||
auto p1 = getNewVar(); // p1=i/k
|
||||
auto p2 = getNewVar(); // p2=i%k
|
||||
Range range1(0, len / k);
|
||||
Range range2(0, k);
|
||||
nnet_assert(range1.second > 0 && range2.second > 0,
|
||||
"Empty loop dim");
|
||||
Replace replace{.iteratorType = IterationType::Loop,
|
||||
.oldIters = {var},
|
||||
.newIters = {p1, p2},
|
||||
.phis = {var / 4, var % 4},
|
||||
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
|
||||
.newVarRanges = {{p1, range1}, {p2, range2}}};
|
||||
ret.emplace_back(replace);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
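
// Worked example (added note, not part of the original commit): splitting a
// sum iterator i over [0, 12) with k = 3 yields p1 in [0, 4) and p2 in
// [0, 3) with i = 3 * p1 + p2 (the psi above), so every (p1, p2) pair maps
// to exactly one original i and the summation is preserved.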

Expr Rule1VariableSplit::replaceIters(Expr cur, const Replace &replace) {
    // TODO [feature]: support multiple replacements in one mutator
    if (replace.oldIters.size() != 1) {
        nnet_unimplemented_continue();
        return nullptr;
    }
    auto replaceMutator =
        ReplaceVariable(replace.oldIters.at(0), replace.psis.at(0));
    auto ret = replaceMutator(cur);
    return ret;
}

} // namespace nnet
@ -0,0 +1,186 @@
#include "nnet/Pass/Rule2VariableMerging.h"
|
||||
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule2VariableMerging::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
// Extract r and s
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
vector<Replace> replaces = getMergableReplaces(cur, depth);
|
||||
// dbg("Start rule2VariableMerging", depth, mergableExprs, *cur);
|
||||
for (const auto &replace : replaces) {
|
||||
if (replace.iteratorType != IterationType::Loop) {
|
||||
nnet_unimplemented_continue();
|
||||
continue;
|
||||
}
|
||||
// replace vars in summand
|
||||
auto replacedSummand = ReplaceKit::replaceMultipleExprs(
|
||||
cur->getSummand(), replace.oldIters, replace.psis, true);
|
||||
// replace var in rangeOp
|
||||
auto inner =
|
||||
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
|
||||
// If OOB happens, this transformation is skipped
|
||||
if (CheckOOBVisitor().checkRangeOp(inner))
|
||||
continue;
|
||||
// build the outerRange{innerRange}[indexForInner] to do DLT
|
||||
auto subscriptedInner =
|
||||
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
|
||||
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
|
||||
|
||||
// next searching step
|
||||
string msg = replace.toReadable();
|
||||
nextStep(origin, depth, rCur, outer, msg);
|
||||
}
|
||||
}
|
||||
|
||||
vector<Replace> Rule2VariableMerging::getMergableReplaces(RangeOp rangeOp,
|
||||
int depth) {
|
||||
vector<Replace> ret;
|
||||
IteratorTable exprIT;
|
||||
if (!exprIT.analyzeExpr(rangeOp)) {
|
||||
nnet_unimplemented_continue();
|
||||
return ret;
|
||||
}
|
||||
exprIT.buildTableWithDefaultMap();
|
||||
const auto &strideInAllDim = exprIT.getStrideInDim();
|
||||
|
||||
set<pair<Iterator, Iterator>, RefValueLess<pair<Iterator, Iterator>>>
|
||||
checkedIterPairs{};
|
||||
// strideInAllDim: [tensorID][dimOfTensor][Iterator]=stride
|
||||
for (size_t tensorID = 0; tensorID < strideInAllDim.size(); ++tensorID) {
|
||||
const auto &strideInDimsOfATensor = strideInAllDim[tensorID];
|
||||
for (const PtrMap<Iterator, int> &strideInADim :
|
||||
strideInDimsOfATensor) {
|
||||
for (const auto &it1 : strideInADim) {
|
||||
for (const auto &it2 : strideInADim) {
|
||||
// Backdoor for rule-based search
|
||||
if (substituteRules.count(depth)) {
|
||||
if (substituteRules[depth].at(0)->neq(it1.first))
|
||||
continue;
|
||||
if (substituteRules[depth].at(1)->neq(it2.first))
|
||||
continue;
|
||||
}
|
||||
if (!(it1.first->equal(it2.first) &&
|
||||
it1.second == it2.second) &&
|
||||
rangeOp->hasLoopVar(it1.first) &&
|
||||
rangeOp->hasLoopVar(it2.first)) {
|
||||
// 2 iters -> 2 iters
|
||||
if (auto opt = getReplaceMappingTwoLoopIters(rangeOp,
|
||||
it1, it2))
|
||||
ret.emplace_back(*opt);
|
||||
|
||||
// 2 iters -> 1 iter
|
||||
const auto iterPair = pair(it1.first, it2.first);
|
||||
if (!checkedIterPairs.count(iterPair)) {
|
||||
checkedIterPairs.insert(iterPair);
|
||||
if (auto opt = getReplaceMergingTwoLoopIters(
|
||||
rangeOp, it1, it2, exprIT, tensorID))
|
||||
ret.emplace_back(*opt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
optional<Replace> Rule2VariableMerging::getReplaceMergingTwoLoopIters(
|
||||
const RangeOp &rangeOp, pair<Iterator, int> pairA,
|
||||
pair<Iterator, int> pairB, const IteratorTable &exprIT, int tensorID) {
|
||||
// 1*A + sb*B -> C
|
||||
// A=C%sb, B=C/sb
|
||||
// ax+by->z, a=1 or -1
|
||||
// For a>0 and b>0 : x=z%b, y=z/b
|
||||
auto x = pairA.first, y = pairB.first;
|
||||
int a = pairA.second, b = pairB.second;
|
||||
if (abs(a) != 1 || abs(a) * abs(b) <= 0)
|
||||
return {};
|
||||
if (a < 0 && b > 0) { // The only unhandled case
|
||||
nnet_unimplemented_continue();
|
||||
return {};
|
||||
}
|
||||
// negative substitution happens only if can be totally merged. So if the
|
||||
// variable appears in another index, skip it.
|
||||
if (a < 0 || b < 0) {
|
||||
if (exprIT.getNumInputs() > 1) {
|
||||
if (exprIT.getStridesInTensor(x, 1 - tensorID) != 0)
|
||||
return {};
|
||||
if (exprIT.getStridesInTensor(y, 1 - tensorID) != 0)
|
||||
return {};
|
||||
}
|
||||
}
|
||||
Range rangeX = rangeOp->getVarRange(x).second,
|
||||
rangeY = rangeOp->getVarRange(y).second;
|
||||
if (rangeX.first != 0 || rangeY.first != 0)
|
||||
return {};
|
||||
int lenX = rangeX.second - rangeX.first;
|
||||
if (abs(b) != lenX)
|
||||
return {};
|
||||
auto z = getNewVar();
|
||||
|
||||
Range rangeExpr{0, 1}; // 1 is the open interval compensation
|
||||
auto calcRangeExpr = [&rangeExpr](int stride, const Range &r) {
|
||||
if (stride > 0) {
|
||||
rangeExpr.first += stride * r.first;
|
||||
rangeExpr.second += stride * (r.second - 1);
|
||||
} else {
|
||||
rangeExpr.first += stride * (r.second - 1);
|
||||
rangeExpr.second += stride * r.first;
|
||||
}
|
||||
};
|
||||
calcRangeExpr(a, rangeX);
|
||||
calcRangeExpr(b, rangeY);
|
||||
|
||||
// build the phi/psi for index transformation
|
||||
// phi: j_x=(i_x...), psi: i_x=(j_x...)
|
||||
auto ret = optional<Replace>();
|
||||
ret.emplace();
|
||||
ret->iteratorType = IterationType::Loop;
|
||||
ret->newIters = {z};
|
||||
ret->oldIters = {x, y};
|
||||
ret->phis = {a * x + b * y - rangeExpr.first};
|
||||
// For b < 0, the psis are not an equavalent replace. Since it must be
|
||||
// simplified (z/b and z%b will be merged), the only important thing is
|
||||
// their strides should be mergable. To merge the strides, an extra minus
|
||||
// are introduced if their stride is negative.
|
||||
ret->psis = {a * (z % b) + a * rangeExpr.first, (b > 0 ? 1 : -1) * (z / b)};
|
||||
ret->newVarRanges = {{z, {0, rangeExpr.second - rangeExpr.first}}};
|
||||
return ret;
|
||||
}
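
// Worked example (added note, not part of the original commit): for x over
// [0, 4) with stride a = 1 and y over [0, 8) with stride b = 4, the merged
// iterator z = x + 4*y ranges over [0, 32), and the psis above recover
// x = z % 4 and y = z / 4, collapsing the two iterators into one.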

optional<Replace>
Rule2VariableMerging::getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
                                                    pair<Iterator, int> pairA,
                                                    pair<Iterator, int> pairB) {
    // the first iterator is replaced, the second remains
    auto i1 = pairA.first, i2 = pairB.first;
    int sa = pairA.second, sb = pairB.second;
    // TODO: can be relaxed to sa | sb
    if (sa != 1 || sb == 0)
        return {};
    if (sb < 0) {
        nnet_unimplemented_continue();
        return {};
    }
    Range rangeA = rangeOp->getVarRange(i1).second;
    Range rangeB = rangeOp->getVarRange(i2).second;
    auto j1 = getNewVar(), j2 = getNewVar();
    Range rangeJ1, rangeJ2 = rangeB;
    assert(pairA.second == 1);
    rangeJ1.first = rangeA.first + rangeB.first * sb;
    rangeJ1.second = rangeA.second + (rangeB.second - 1) * sb;
    // build the phi/psi for index transformation
    // phi: j_x=(i_x...), psi: i_x=(j_x...)
    auto ret = optional<Replace>();
    ret.emplace();
    ret->iteratorType = IterationType::Loop;
    ret->newIters = {j1, j2};
    ret->oldIters = {i1, i2};
    ret->newVarRanges = {{j1, rangeJ1}, {j2, rangeJ2}};
    ret->phis = {sa * i1 + sb * i2, i2};
    ret->psis = {j1 - (sb / sa) * j2, j2};
    return ret;
}

} // namespace nnet
@ -0,0 +1,82 @@
#include "nnet/Pass/Rule3StageSplit.h"
|
||||
#include "nnet/permutation.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule3StageSplit::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
vector<vector<Iterator>> splitSchemes = getSplitSummationIters(cur);
|
||||
|
||||
for (const auto &varSplit : splitSchemes) {
|
||||
|
||||
bool isSplittable = false;
|
||||
for (const auto &splitted : varSplit)
|
||||
if (cur->hasSumVar(splitted))
|
||||
isSplittable = true;
|
||||
assert(isSplittable);
|
||||
|
||||
const vector<VarRangePair> loopVars = cur->getLoopVarRanges(),
|
||||
sumVars = cur->getSumVarRanges();
|
||||
// move iterators from Sigma to Loop
|
||||
vector<VarRangePair> innerLoopVars, innerSumVars, outerSumVars;
|
||||
VecExpr indexForInner;
|
||||
for (const auto &kv : sumVars) {
|
||||
bool isSplitted = false;
|
||||
for (const auto &iter : varSplit)
|
||||
if (iter == kv.first->getName())
|
||||
isSplitted = true;
|
||||
if (isSplitted) {
|
||||
innerLoopVars.emplace_back(kv);
|
||||
outerSumVars.emplace_back(kv);
|
||||
} else
|
||||
innerSumVars.emplace_back(kv);
|
||||
}
|
||||
innerLoopVars.insert(innerLoopVars.end(), loopVars.begin(),
|
||||
loopVars.end());
|
||||
for (const auto &[var, _] : innerLoopVars)
|
||||
indexForInner.emplace_back(var);
|
||||
|
||||
// if no sum iterator, the stage is redundant
|
||||
assert(!innerSumVars.empty());
|
||||
auto inner =
|
||||
makeRangeOperator(innerLoopVars, innerSumVars, cur->getSummand());
|
||||
auto subscriptedInner = make_ref<SubscriptNode>(inner, indexForInner);
|
||||
auto outer = makeRangeOperator(cur->getLoopVarRanges(), outerSumVars,
|
||||
subscriptedInner);
|
||||
outer->setPaddings(cur->getPaddings());
|
||||
|
||||
// next searching step
|
||||
string msg = "Separate sum iters: " + serializeVec(varSplit);
|
||||
nextStep(origin, depth, rCur, outer, msg);
|
||||
}
|
||||
}
|
||||
|
||||
vector<vector<Iterator>>
|
||||
Rule3StageSplit::getSplitSummationIters(RangeOp rangeOp) {
|
||||
// set<string> varSplit = {"r", "s", "i3", "i13"};
|
||||
vector<vector<Iterator>> ret;
|
||||
// Rule-based Hint
|
||||
// vector<vector<Iterator>> heuristics = {{"r", "s"}, {"i3", "i13"}};
|
||||
// for (const auto &iterSet : heuristics) {
|
||||
// bool notExist = false;
|
||||
// for (const auto &iter : iterSet)
|
||||
// if (!rangeOp->hasSumVar(iter))
|
||||
// notExist = true;
|
||||
// if (!notExist)
|
||||
// ret.emplace_back(iterSet);
|
||||
// }
|
||||
// if (!rulesOverall.empty())
|
||||
// return ret;
|
||||
vector<Iterator> sumIters;
|
||||
for (const auto &[iter, range] : rangeOp->getSumVarRanges())
|
||||
sumIters.emplace_back(iter);
|
||||
if (sumIters.size() <= 1)
|
||||
return ret;
|
||||
SubsetGenerator gen(sumIters);
|
||||
do {
|
||||
ret.emplace_back(gen.get());
|
||||
} while (gen.next());
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,91 @@
|
|||
#include "nnet/Pass/Rule4StageMerging.h"
|
||||
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||
#include "nnet/Visitor/ReplaceVariable.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule4StageMerging::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
success = rule4StageMerging(origin, depth, rCur, mergeStageWithCalc);
|
||||
}
|
||||
|
||||
bool Rule4StageMerging::isSuccessful() { return success; }
|
||||
|
||||
void Rule4StageMerging::setMergeStageWithCalc(bool value) {
|
||||
mergeStageWithCalc = value;
|
||||
}
|
||||
|
||||
bool Rule4StageMerging::rule4StageMerging(Formula &origin, int depth,
|
||||
Expr &rCur, bool mergeStageWithCalc) {
|
||||
auto rangeOp0 = as<RangeOpNode>(rCur);
|
||||
const Subscript &sub0 = as<SubscriptNode>(rangeOp0->getSummand());
|
||||
if (!sub0)
|
||||
return false;
|
||||
const auto &rangeOp1 = as<RangeOpNode>(sub0->getObject());
|
||||
if (!rangeOp1)
|
||||
return false;
|
||||
const auto &sub1 = as<SubscriptNode>(rangeOp1->getSummand());
|
||||
if (!sub1)
|
||||
return false;
|
||||
// merge stage with calculation only when mergeStageWithCalc=true
|
||||
if (!mergeStageWithCalc && !rangeOp1->getSumVarRanges().empty())
|
||||
return false;
|
||||
// Only propogate paddings in perfect nested dimension
|
||||
if (rangeOp1->hasPaddings()) {
|
||||
auto oldTensor = as<TensorNode>(sub1->getObject());
|
||||
if (!oldTensor) {
|
||||
nnet_unimplemented_continue();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
// repalce variables: iters of rangeOp1 repalced by indexes of sub0
|
||||
map<string, pair<Expr, Expr>> varMapping;
|
||||
assert(sub0->getDims() == rangeOp1->getLoopVarRanges().size());
|
||||
for (size_t i = 0; i < sub0->getDims(); ++i) {
|
||||
varMapping[rangeOp1->getLoopVar(i)->getName()] =
|
||||
pair(rangeOp1->getLoopVar(i), sub0->getIndex(i));
|
||||
}
|
||||
ReplaceVariable replaceVariable{varMapping};
|
||||
auto merged = make_ref<RangeOpNode>(*rangeOp0);
|
||||
merged->setSummand(replaceVariable(sub1));
|
||||
// a naive approach to propogate paddings
|
||||
if (rangeOp1->hasPaddings()) {
|
||||
auto oldTensor = as<TensorNode>(sub1->getObject());
|
||||
auto newTensor = make_ref<TensorNode>(*oldTensor);
|
||||
for (int i = 0; i < rangeOp1->getNumOutputDims(); ++i) {
|
||||
if (rangeOp1->getPaddings(i) == 0)
|
||||
continue;
|
||||
auto loopVar = rangeOp1->getLoopVar(i);
|
||||
// FIXME: in fact this var should not appear in other index as well,
|
||||
// which may result in OOB
|
||||
bool findSingleVarAsIndex = false;
|
||||
for (size_t subIndexID = 0; subIndexID < sub1->getDims();
|
||||
++subIndexID) {
|
||||
auto index = sub1->getIndex(subIndexID);
|
||||
if (auto indexVar = as<VarNode>(index);
|
||||
indexVar && (indexVar->equal(loopVar))) {
|
||||
newTensor->setPadding(subIndexID,
|
||||
newTensor->getPadding(subIndexID) +
|
||||
rangeOp1->getPaddings(i));
|
||||
findSingleVarAsIndex = true;
|
||||
}
|
||||
}
|
||||
if (!findSingleVarAsIndex) {
|
||||
nnet_unimplemented_continue();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
merged = as<RangeOpNode>(
|
||||
ReplaceNodeMutator().replace(merged, oldTensor.get(), newTensor));
|
||||
assert(merged != nullptr);
|
||||
}
|
||||
// Merge inner stage sums
|
||||
if (!rangeOp1->getSumVarRanges().empty())
|
||||
merged->setSumIterator(rangeOp1->getSumVarRanges());
|
||||
|
||||
// next searching step
|
||||
// if mergeStageWithCalc, depth counts for invocation in rule-based search
|
||||
nextStep(origin, (mergeStageWithCalc) ? depth : depth - 1, rCur, merged);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,72 @@
|
|||
#include "nnet/Pass/Rule5RangeRelaxation.h"
|
||||
#include "nnet/Visitor/RangeRelaxFunctor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule5RangeRelaxation::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
rule5RangeRelaxation(origin, depth, rCur);
|
||||
}
|
||||
|
||||
Expr Rule5RangeRelaxation::rule5RangeRelaxation(Formula &origin, int depth,
|
||||
Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
if (cur->hasPaddings()) {
|
||||
// string msg = "====== END rule5RangeRelaxation: Paddings exist \n";
|
||||
// dbg(msg);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Infer meaningful calculation range
|
||||
RangeRelaxFunctor rangeRexlaxtionFunctor{cur};
|
||||
RangeMap rangeMap = rangeRexlaxtionFunctor(cur);
|
||||
auto relaxedCur = make_ref<RangeOpNode>(*cur);
|
||||
bool isRelaxed = false;
|
||||
vector<int> paddings;
|
||||
// check whether narrow the calculation range
|
||||
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
|
||||
const auto &[iter, iterRange] =
|
||||
cur->getVarRange(IterationType::Loop, i);
|
||||
if (auto it = rangeMap.find(iter); it != rangeMap.end()) {
|
||||
// intersection of validRange and iterRange is necessary computation
|
||||
// TODO: it is redundant with RangeRelaxFunctor::intersectRangeMaps.
|
||||
// An independent Range class might be necessary.
|
||||
const Range &validRange = it->second;
|
||||
Range relaxedRange{max(iterRange.first, validRange.first),
|
||||
min(iterRange.second, validRange.second)};
|
||||
if (relaxedRange != iterRange) {
|
||||
isRelaxed = true;
|
||||
relaxedCur->setVarRange(IterationType::Loop, i,
|
||||
{iter, relaxedRange});
|
||||
paddings.emplace_back(
|
||||
max(relaxedRange.first - iterRange.first,
|
||||
iterRange.second - relaxedRange.second));
|
||||
} else
|
||||
paddings.emplace_back(0);
|
||||
} else
|
||||
paddings.emplace_back(0);
|
||||
}
|
||||
relaxedCur->setPaddings(paddings);
|
||||
if (!isRelaxed) {
|
||||
// string msg = "====== END rule5RangeRelaxation: Relaxation not
|
||||
// found\n"; dbg(msg);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// next searching step
|
||||
string detailedMsg;
|
||||
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
|
||||
const auto &[v, a] = cur->getVarRange(IterationType::Loop, i);
|
||||
const auto &[_, b] = relaxedCur->getVarRange(IterationType::Loop, i);
|
||||
if (a != b) {
|
||||
detailedMsg += v->getName();
|
||||
detailedMsg +=
|
||||
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
|
||||
detailedMsg +=
|
||||
" (" + to_string(b.first) + "," + to_string(b.second) + "),";
|
||||
}
|
||||
}
|
||||
nextStep(origin, depth, rCur, relaxedCur, detailedMsg);
|
||||
return relaxedCur;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,57 @@
|
|||
#include "nnet/Pass/Rule6KenerlMatching.h"
|
||||
#include "nnet/Visitor/InputVisitor.h"
|
||||
#include "nnet/Visitor/PatternMatcher.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule6KenerlMatching::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
// Build wrapper stages for enforce axis starts from 0
|
||||
PatternMatcher patternMatcher(derivator, cur);
|
||||
cur = patternMatcher.getOffsetCur();
|
||||
|
||||
// Match matchable routines
|
||||
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
|
||||
auto targetOp = idToRoutineType(i);
|
||||
// During guided search, only check the target OP
|
||||
if (derivator.getTargetOp() != RoutineType::NoneType &&
|
||||
derivator.getTargetOp() != targetOp)
|
||||
continue;
|
||||
auto replaces =
|
||||
patternMatcher.matchWithPattern(cur, getPattern(targetOp));
|
||||
replaces = patternMatcher.applyWrapper(replaces);
|
||||
for (auto newCur : replaces)
|
||||
nextStep(origin, depth, rCur, newCur);
|
||||
}
|
||||
{ // Match element-wise OP
|
||||
auto replaces = matchElementWise(cur);
|
||||
if (!replaces.empty())
|
||||
dbg(rCur);
|
||||
for (auto newCur : replaces)
|
||||
nextStep(origin, depth, rCur, newCur);
|
||||
}
|
||||
}
|
||||
|
||||
VecExpr Rule6KenerlMatching::matchElementWise(const RangeOp &rangeOp) {
|
||||
// If the stage is compute bound, then do not convert it.
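    // Added note (not in the original commit): flops / (inputSize +
    // outputSize) is the arithmetic intensity, i.e. FLOPs per element
    // accessed; a stage above the threshold of 3 is treated as compute
    // bound and left to the computation-kernel matchers.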
    int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
    int64_t inputSize = rangeOp->getInputSize(rangeOp);
    if (double(flops) / (inputSize + outputSize) > 3)
        return {};

    vector<int> newShape;
    for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
        if (range.first != 0) {
            nnet_unimplemented_continue();
            return {};
        }
        newShape.emplace_back(range.second - range.first);
    }
    const auto &inputs = InputVisitor().getInputs(rangeOp);
    auto source =
        make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
    auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
    return {newTensor};
}

} // namespace nnet
@ -0,0 +1,78 @@
#include "nnet/Pass/Rule7DLT.h"
|
||||
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||
#include "nnet/dlt.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule7DLT::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
if (!cur)
|
||||
return;
|
||||
auto op = as<BinaryOpNode>(cur->getSummand());
|
||||
if (!op)
|
||||
return;
|
||||
auto subs = {op->getLhs(), op->getRhs()};
|
||||
for (auto subExpr : subs) {
|
||||
auto sub = as<SubscriptNode>(subExpr);
|
||||
if (!sub)
|
||||
continue;
|
||||
auto tensor = as<TensorNode>(sub->getObject());
|
||||
if (!tensor)
|
||||
continue;
|
||||
// // HACK for G2BMM
|
||||
// if (tensor->getDims() != 3)
|
||||
// continue;
|
||||
for (const auto factor : getFactors()) {
|
||||
for (int targetDim = 0; targetDim < tensor->getDims();
|
||||
++targetDim) {
|
||||
if (tensor->getShape(targetDim) % factor)
|
||||
continue;
|
||||
// Debug hint for G2BMM
|
||||
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
|
||||
if (tensor->getShape(targetDim) != 10000)
|
||||
continue;
|
||||
assert(targetDim == 1);
|
||||
}
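                // Added note (not in the original commit): the three DLT
                // steps below split targetDim by `factor`, swap the two
                // resulting sub-dimensions, and merge them back, permuting
                // the layout of that dimension while keeping its total
                // extent.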
                DLT dlt;
                dlt.split(targetDim, factor);
                vector<int> newOrder(tensor->getDims() + 1);
                for (int i = 0; i < tensor->getDims() + 1; ++i)
                    newOrder[i] = i;
                newOrder[targetDim]++;
                newOrder[targetDim + 1]--;
                dlt.reorder(newOrder);
                dlt.merge(targetDim, targetDim + 1);
                if (auto opt = dlt.apply(cur, sub, newTensorName())) {
                    Expr newSummand = ReplaceNodeMutator().replace(
                        cur->getSummand(), sub.get(), *opt);
                    auto newCur = buildDLTSingleRangeOp(cur, newSummand);

                    // next searching step
                    string msg = "====== END rule7DLT\n";
                    dbg(msg);
                    nextStep(origin, depth, rCur, newCur);
                }
            }
        }
    }
}

Expr Rule7DLT::buildDLTSingleRangeOp(const RangeOp &original,
                                     const Expr &newSummand) {
    auto rangeOp = make_ref<RangeOpNode>(*original);
    rangeOp->setSummand(newSummand);
    return rangeOp;
}

vector<int> Rule7DLT::getFactors() {
    if (derivator.getPassMode() == Derivator::PassMode::Debug) {
        return {4};
    } else if (derivator.getPassMode() == Derivator::PassMode::Full) {
        return {3, 4};
    } else {
        nnet_unimplemented_halt();
        return {};
    }
}

} // namespace nnet
@ -0,0 +1,317 @@
#include "nnet/Pass/Rule8GuidedDLT.h"
|
||||
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||
|
||||
namespace nnet {
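
// Added note (not in the original commit): bitCount uses Kernighan's trick,
// n & (n - 1) clears the lowest set bit, so the loop runs once per set bit.
// bitPosition returns the index of the single set bit, i.e. log2(n) for a
// power of two.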

static int bitCount(unsigned int n) {
    int count = 0;
    while (n != 0) {
        n = n & (n - 1);
        count++;
    }
    return count;
}

static int bitPosition(unsigned int n) {
    assert(bitCount(n) == 1);
    int ret = 0;
    for (n >>= 1; n; n >>= 1)
        ++ret;
    return ret;
}

void Rule8GuidedDLT::transform(Formula &origin, int depth, Expr &rCur) {
    guidedDLT(origin, depth, rCur);
}

VecExpr Rule8GuidedDLT::guidedDLT(Formula &origin, int depth, Expr &rCur,
                                  bool debug) {
    string detailedMsg;
    VecExpr ret;
    auto cur = as<RangeOpNode>(rCur);
    // check that cur satisfies T1[A]*T2[B]

    if (!statisfyGuidedDLT(cur))
        return ret;

    IteratorTable exprIT;
    if (!exprIT.analyzeExpr(cur))
        return ret;
    exprIT.buildTableWithDefaultMap();

    bool setTargetOpHere = false;
    for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
        // check that the flag was correctly unset in the previous iteration
        assert(setTargetOpHere == false);
        // If the guide direction is set
        if (derivator.getTargetOp() != RoutineType::NoneType &&
            idToRoutineType(i) != derivator.getTargetOp())
            continue;
        // Warning: do not `continue` before the targetOp is unset below
        if (derivator.getTargetOp() == RoutineType::NoneType) {
            setTargetOpHere = true;
            derivator.setTargetOp(idToRoutineType(i));
        }
        const Pattern &pattern = getPattern(derivator.getTargetOp());
        auto mismatches = exprIT.matchPatternIT(pattern);
        // Prune unlikely results

        // std::cout << "mismatches= " << mismatches.size()
        //           << "; setTargetOpHere: " << setTargetOpHere << "; ";
        // std::cout << "TargetOp = "
        //           << static_cast<int>(derivator.getTargetOp())
        //           << "; mismatches : ";
        // for (const auto i : mismatches)
        //     std::cout << static_cast<int>(i.type) << " ";
        // std::cout << endl;
        if (mismatches.size() == 0) {
            derivator.setSearchState(2);
            nextStep(origin, depth, rCur, rCur);
            derivator.setSearchState(1);
        }
        if (mismatches.size() > 0 && mismatches.size() <= 2) {
            for (const auto &mismatch : mismatches) {
                Expr newCur;
                if (mismatch.type == MismatchType::MoreVar) {
                    newCur = guidedDLTMoreVar2(cur, mismatch, exprIT, pattern);
                    detailedMsg += "guidedDLTMoreVar2 ";
                } else if (mismatch.type == MismatchType::DLMismatch ||
                           mismatch.type == MismatchType::OutputDLMismatch) {
                    if (mismatches.size() > 1) {
                        nnet_unimplemented_continue();
                        break;
                    }
                    newCur =
                        guidedDLTDLMismatch(cur, mismatch, exprIT, pattern);
                    detailedMsg += "guidedDLTDLMismatch ";
                }
                // std::cout << "newCur= "
                //           << ((newCur == nullptr) ? "Nullptr"
                //                                   : newCur->toReadable())
                //           << endl;
                if (!newCur)
                    continue;
                if (debug)
                    ret.emplace_back(newCur);
                // next searching step
                detailedMsg = "Toward " +
                              getPatternName(derivator.getTargetOp()) + ". " +
                              detailedMsg;
                nextStep(origin, depth, rCur, newCur, detailedMsg);
            }
        }
        // Unset targetOp
        if (setTargetOpHere) {
            derivator.setTargetOp(RoutineType::NoneType);
            setTargetOpHere = false;
        }
    }
    return ret;
}

Expr Rule8GuidedDLT::guidedDLTDLMismatch(
    const RangeOp &cur, const Mismatch &mismatch,
    [[maybe_unused]] const IteratorTable &exprIT, const Pattern &pattern) {
    assert(mismatch.type == MismatchType::DLMismatch ||
           mismatch.type == MismatchType::OutputDLMismatch);
    // Currently only deals with output DLT
    if (mismatch.bitmap != pattern.getNumInputs()) {
        nnet_unimplemented_continue();
        return nullptr;
    }
    vector<VarRangePair> newVarRanges;
    for (const auto &[var, _] : pattern.getRangeOp()->getLoopVarRanges()) {
        const auto &iterInExpr = mismatch.mappingIter_r.at(var);
        newVarRanges.emplace_back(cur->getVarRange(iterInExpr));
    }
    auto inner = make_ref<RangeOpNode>(*cur);
    inner->setLoopIterator(newVarRanges);
    auto subscriptedInner =
        ReplaceKit::buildSubscirptForLoopVarReplace(inner, {});
    auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
    return outer;
}

bool Rule8GuidedDLT::statisfyGuidedDLT(RangeOp cur) const {
    auto mul = as<BinaryOpNode>(cur->getSummand());
    if (!mul)
        return false;
    if (mul->getOpType() != OpType::Mul)
        return false;
    return as<SubscriptNode>(mul->getLhs()) && as<SubscriptNode>(mul->getRhs());
}

Expr Rule8GuidedDLT::guidedDLTMoreVar2(const RangeOp &cur,
                                       const Mismatch &mismatch,
                                       const IteratorTable &exprIT,
                                       const Pattern &pattern) {
    int bitmap = mismatch.bitmap;
    const auto &mergedItersDefaultOrder = exprIT.getPosTable(bitmap);

    // Assure vars only appear in one input tensor
    int bitmapOfInputs = bitmap & ((1 << exprIT.getNumInputs()) - 1);
    if (bitCount(bitmapOfInputs) > 1)
        return nullptr;
    if (pattern.getPosTable(bitmap).size() != 1) {
        nnet_unimplemented_continue();
        return nullptr;
    }
    if (mergedItersDefaultOrder.size() < 1)
        return nullptr;
    int tensorID = bitPosition(bitmapOfInputs);
    if (!checkElementsHaveOnlyOneAccessIteratorSet(exprIT, tensorID))
        return nullptr;
    vector<Var> oldVars; // i_1, ...
    vector<Var> newVars; // j_1, ...
    VecExpr psis;        // i_1=\psi_1(j_1, ...)
    VecExpr phis;        // j_1=\phi_1(i_1, ...), not necessary for Sum iter
    vector<VarRangePair> newVarRanges;

    auto originalTensor = exprIT.getTensor(tensorID);
    auto originalSub = exprIT.getSubscript(tensorID);
    vector<bool> mergedDims(originalTensor->getDims());

    // Heuristic: merge iters according to their appearance positions
    std::multimap<int, Var> sortedMergedIters;
    for (const auto &iter : mergedItersDefaultOrder) {
        vector<int> dims = exprIT.getIterDimInTensor(tensorID, iter);
        assert(dims.size() == 1);
        sortedMergedIters.emplace(dims[0], iter);
    }
    vector<Var> mergedIters; // decides the order of fused dims
    for (const auto &[_, v] : sortedMergedIters)
        mergedIters.emplace_back(v);

    // Add the merged iterators
    const auto newVar = getNewVar();
    newVars.emplace_back(newVar);
    int newRange = 1;
    for (const auto &iter : mergedIters) {
        oldVars.emplace_back(iter);
        auto range = cur->getRange(iter);
        newRange *= (range.second - range.first);
        // if (range.first == 0)
        //     nnet_unimplemented_halt();
    }
    newVarRanges.emplace_back(newVar, Range{0, newRange});
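    // Added note (not in the original commit): the loop below decomposes the
    // fused iterator row-major; e.g. merging i1 over [0, 2) and i2 over
    // [0, 3) gives newVar over [0, 6) with i1 = newVar / 3 and
    // i2 = newVar % 3, plus a range offset when a source range does not
    // start at 0.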
    // Add psis for each old iterator
    int remainingRange = newRange;
    Expr phi = nullptr;
    for (const auto &iter : mergedIters) {
        auto oldVar = iter;
        auto range = cur->getRange(iter);
        int len = (range.second - range.first);
        remainingRange /= len;
        Expr psi = newVar;
        if (remainingRange > 1)
            psi = psi / remainingRange;
        if (newRange > remainingRange * len)
            psi = psi % len;
        int start = cur->getRange(iter).first;
        if (start != 0)
            psi = psi + start;
        psis.emplace_back(psi);
        phi = phi + remainingRange * (oldVar - start);
    }
    Replace replace{.iteratorType = IterationType::Loop,
                    .oldIters = oldVars,
                    .newIters = newVars,
                    .phis = VecExpr{phi},
                    .psis = psis,
                    .newVarRanges = newVarRanges};
    // HACK: decide the rebuilt data shape order
    // TODO: get a partial iter mapping and permute them?
    vector<Var> tensorDimAxes{newVars};
    vector<int> newShape;
    for (const auto &[var, range] : newVarRanges)
        newShape.emplace_back(range.second - range.first);
    for (int row = 0; row < exprIT.getNumRows(); ++row) {
        // Deal with other dimensions of the current tensor
        if (row == bitmap || ((row & (1 << tensorID)) == 0))
            continue;
        using StrideIter = tuple<int, int, Iterator>;
        vector<StrideIter> strideIters;

        for (size_t i = 0; i < exprIT.getPosTable(row).size(); ++i) {
            const auto &iter = exprIT.getPosTable(row)[i];
            const Range range = cur->getRange(iter);
            const int len = range.second - range.first;

            // HACK: sort according to the original stride (keep the
            // original order)
            strideIters.emplace_back(-exprIT.getStridesInTensor(iter, tensorID),
                                     len, iter);

            // // HACK for conv
            // if (iter == "n")
            //     strideIters.emplace_back(2, len, iter);
            // else if (iter == "c")
            //     strideIters.emplace_back(1, len, iter);
            // else
            //     strideIters.emplace_back(0, len, iter);
        }
        // HACK: assure the order of iterators
        std::sort(strideIters.begin(), strideIters.end(),
                  ref_value_less<StrideIter>);
        for (const auto &[_, len, oldIter] : strideIters) {
            const auto &oldVar = oldIter;
            tensorDimAxes.emplace_back(oldVar);
            newShape.emplace_back(len);
        }
    }

    // build DLT source
    const auto sourceExpr =
        buildGuidedDLTSource(originalSub, replace, tensorDimAxes, newShape);
    const auto sourceRoutine = make_ref<ElementWiseNode>(
        sourceExpr, vector<Tensor>{originalTensor}, newShape);
    // build stage connections
    const auto newTensor =
        makeTensor(newTensorName(), newShape, {}, sourceRoutine);
    const auto &newSub = makeSubscript(
        newTensor, VecExpr(tensorDimAxes.begin(), tensorDimAxes.end()));
    // TODO [1124]: get variable mapping and reorder L according to it
    // dbg(cur, originalSub, newSub, newVarRanges, replace.toReadable(),
    //     tensorDimAxes, newShape);

    // Replace the entire subscript (A[xxxxx,xxx]) in the summand
    Expr newSummand = ReplaceNodeMutator().replace(cur->getSummand(),
                                                   originalSub.get(), newSub);
    auto inner = ReplaceKit::replaceRangeOpIterator(cur, replace, newSummand);
    auto subscriptedInner =
        ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
    auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
    return outer;
}

bool Rule8GuidedDLT::checkElementsHaveOnlyOneAccessIteratorSet(
    const IteratorTable &exprIT, int tensorID) {
    const auto &strideInDim = exprIT.getStrideInDim();
    for (const auto &strideForOneDim : strideInDim[tensorID]) {
        vector<pair<int, int>> strideLengthPairs;
        for (const auto &[iter, s] : strideForOneDim) {
            const auto &range = exprIT.getRangeOp()->getRange(iter);
            strideLengthPairs.emplace_back(s, range.second - range.first);
        }
        std::sort(strideLengthPairs.begin(), strideLengthPairs.end());
        for (size_t i = 0; i < strideLengthPairs.size() - 1; ++i) {
            const auto &[stride, length] = strideLengthPairs[i];
            if (stride * length > strideLengthPairs[i + 1].first)
                return false;
        }
    }
    return true;
}

Expr Rule8GuidedDLT::buildGuidedDLTSource(const Subscript &originalSub,
                                          Replace replace,
                                          vector<Var> tensorDimAxes,
                                          vector<int> newShape) {
    Expr newSub = ReplaceKit::replaceMultipleExprs(
        originalSub, replace.oldIters, replace.psis, true);
    vector<VarRangePair> loopVarRangePairs;
    for (size_t i = 0; i < tensorDimAxes.size(); ++i)
        loopVarRangePairs.emplace_back(tensorDimAxes[i], pair(0, newShape[i]));
    return makeRangeOperator(loopVarRangePairs, {}, newSub);
}

} // namespace nnet
@ -0,0 +1,54 @@
#include "nnet/Pass/Rule90TwoStageElementWise.h"
|
||||
#include "nnet/Visitor/InputVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule90TwoStageElementWise::transform(Formula &origin, int depth,
|
||||
Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
{ // Match element-wise OP
|
||||
auto replaces = matchTwoStageElementWise(cur);
|
||||
// if (!replaces.empty())
|
||||
// dbg(rCur);
|
||||
// dbg(replaces);
|
||||
for (auto newCur : replaces)
|
||||
nextStep(origin, depth, rCur, newCur);
|
||||
}
|
||||
}
|
||||
|
||||
VecExpr
|
||||
Rule90TwoStageElementWise::matchTwoStageElementWise(const RangeOp &rangeOp) {
|
||||
// If the stage is compute bound, then do not convert it.
|
||||
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
|
||||
int64_t inputSize = rangeOp->getInputSize(rangeOp);
|
||||
if (double(flops) / (inputSize + outputSize) > 3)
|
||||
return {};
|
||||
auto outerSub = as<SubscriptNode>(rangeOp->getSummand());
|
||||
if (!outerSub)
|
||||
return {};
|
||||
auto innerRangeOp = as<RangeOpNode>(outerSub->getObject());
|
||||
if (!innerRangeOp)
|
||||
return {};
|
||||
auto innerSub = as<SubscriptNode>(innerRangeOp->getSummand());
|
||||
if (!innerSub)
|
||||
return {};
|
||||
auto innerTensor = as<TensorNode>(innerSub->getObject());
|
||||
if (!innerTensor)
|
||||
return {};
|
||||
|
||||
vector<int> newShape;
|
||||
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
|
||||
if (range.first != 0) {
|
||||
nnet_unimplemented_continue();
|
||||
return {};
|
||||
}
|
||||
newShape.emplace_back(range.second - range.first);
|
||||
}
|
||||
const auto &inputs = InputVisitor().getInputs(rangeOp);
|
||||
auto source =
|
||||
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
|
||||
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
|
||||
return {newTensor};
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,11 @@
|
|||
#include "nnet/Pass/Rule91MergeStagesWithSum.h"
|
||||
#include "nnet/Pass/Rule4StageMerging.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule91MergeStagesWithSum::transform(Formula &origin, int depth,
|
||||
Expr &rCur) {
|
||||
Rule4StageMerging(derivator).rule4StageMerging(origin, depth, rCur, true);
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,45 @@
|
|||
#include "nnet/Pass/Rule9RangeMagnify.h"
|
||||
#include "nnet/Visitor/RangeMagnifyVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void Rule9RangeMagnify::transform(Formula &origin, int depth, Expr &rCur) {
|
||||
auto cur = as<RangeOpNode>(rCur);
|
||||
if (cur->hasPaddings()) {
|
||||
// string msg = "====== END rule9RangeMagnify: Paddings exist \n";
|
||||
// dbg(msg);
|
||||
return;
|
||||
}
|
||||
// HACK for conv5x5
|
||||
vector<VarRangePair> newSumVarRanges;
|
||||
for (const auto &[var, range] : cur->getSumVarRanges()) {
|
||||
if (range.first == 0 && range.second == 5) {
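            // Added note (not in the original commit): (5 + 2) / 3 * 3
            // rounds the extent up to 6, the next multiple of 3, so the
            // range [0, 5) is magnified to [0, 6).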
            newSumVarRanges.emplace_back(
                var, Range{range.first, (range.second + 2) / 3 * 3});
        } else
            newSumVarRanges.emplace_back(var, range);
    }
    if (newSumVarRanges.empty())
        return;
    auto magnifiedCur = RangeMagnifyVisitor().magnify(cur, newSumVarRanges);
    if (!magnifiedCur)
        return;

    // next searching step
    string msg = "relax iterating ranges ";
    for (size_t i = 0; i < cur->getSumVarRanges().size(); ++i) {
        const auto &[v1, a] = cur->getVarRange(IterationType::Sum, i);
        const auto &[v2, b] = magnifiedCur->getVarRange(IterationType::Sum, i);
        assert(v1->getName() == v2->getName());
        if (a != b) {
            msg += v1->getName();
            msg +=
                " (" + to_string(a.first) + "," + to_string(a.second) + ") to";
            msg += " (" + to_string(b.first) + "," + to_string(b.second) + "),";
        }
    }
    nextStep(origin, depth, rCur, magnifiedCur, msg);
}

} // namespace nnet
@ -0,0 +1,165 @@
#include "nnet/Visitor/AsTVMVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
std::string AsTVMVisitor::visit_(const Constant &c) {
|
||||
return std::to_string(c->getValue());
|
||||
}
|
||||
std::string AsTVMVisitor::visit_(const BinaryOp &c) {
|
||||
    switch (c->getOpType()) {
    case OpType::Add:
        return "(" + dispatch(c->getLhs()) + " + " + dispatch(c->getRhs()) +
               ")";
    case OpType::Sub:
        return "(" + dispatch(c->getLhs()) + " - " + dispatch(c->getRhs()) +
               ")";
    case OpType::Mul:
        return "(" + dispatch(c->getLhs()) + " * " + dispatch(c->getRhs()) +
               ")";
    case OpType::Div:
        return "(" + dispatch(c->getLhs()) + " // " + dispatch(c->getRhs()) +
               ")";
    case OpType::Mod:
        return "(" + dispatch(c->getLhs()) + " % " + dispatch(c->getRhs()) +
               ")";
    default:
        assert(false);
    }
}
std::string AsTVMVisitor::visit_(const Func &c) {
    switch (c->getFuncType()) {
    case FuncType::Relu:
        // TODO: Deduce the dtype
        return "te.max(" + dispatch(c->getObject()) +
               ", tvm.tir.const(0, 'float32'))";
    case FuncType::Tanh:
        return "te.tanh(" + dispatch(c->getObject()) + ")";
    default:
        assert(false);
    }
}
std::string AsTVMVisitor::visit_(const RangeOp &c) {
    auto outerStage = curStage;
    curStage = nStage++;

    std::string stmt;
    std::string stageName = "s" + std::to_string(curStage);
    std::vector<std::string> reduceVars;
    for (auto &&[var, range] : c->getSumVarRanges()) {
        std::string varName = stageName + "_" + var->getName();
        stmt += varName + " = " + "te.reduce_axis((" +
                std::to_string(range.first) + ", " +
                std::to_string(range.second) + "), name=\"" + varName + "\")\n";
        reduceVars.emplace_back(varName);
        pythonVars.emplace_back(varName);
    }
    std::vector<int> shape;
    stmt += stageName + " = te.compute((";
    for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
        auto &&[var, range] = c->getLoopVarRanges()[i];
        std::string varName = stageName + "_" + var->getName();
        offset[varName] = -range.first + c->getPaddings(i);
        auto len = range.second - range.first + 2 * c->getPaddings(i);
        stmt += std::to_string(len) + ", ";
        shape.emplace_back(len);
    }
    stmt += "), lambda ";
    bool first = true;
    for (auto &&[var, range] : c->getLoopVarRanges()) {
        std::string varName = stageName + "_" + var->getName();
        stmt += (first ? "" : ", ") + varName;
        first = false;
    }
    std::string summand = dispatch(c->getSummand());
    if (!reduceVars.empty()) {
        summand = "te.sum(" + summand + ", axis=(";
        for (auto &&var : reduceVars) {
            summand += var + ", ";
        }
        summand += "))";
    }
    if (c->hasPaddings()) {
        std::string guard = "tir.if_then_else(tir.all(";
        bool first = true;
        for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
            auto &&[var, range] = c->getLoopVarRanges()[i];
            std::string varName = stageName + "_" + var->getName();
            if (auto pad = c->getPaddings(i); pad > 0) {
                guard += (first ? "" : ", ") + varName +
                         " >= " + std::to_string(range.first) + ", " + varName +
                         " < " + std::to_string(range.second);
                first = false;
            }
        }
        // TODO: Deduce the dtype
        guard += "), " + summand + ", tvm.tir.const(0.0, \"float32\"))";
        summand = guard;
    }
    stmt += ": " + summand + ")";
    stmts += stmt + "\n";

    pythonVars.emplace_back(stageName);
    output = stageName;
    outputShape = std::move(shape);
    curStage = outerStage;
    return stageName;
}
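// For orientation, a sketch of the Python that visit_(RangeOp) above emits
// for a hypothetical 16x16 matmul-like stage (exact names depend on the
// nStage counter, ranges, and paddings):
//   s0_k = te.reduce_axis((0, 16), name="s0_k")
//   s0 = te.compute((16, 16, ), lambda s0_i, s0_j: te.sum((A[s0_i, s0_k] *
//        B[s0_k, s0_j]), axis=(s0_k, )))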
std::string AsTVMVisitor::visit_(const Subscript &c) {
    std::string str = dispatch(c->getObject()) + "[";
    for (size_t i = 0, n = c->getIndex().size(); i < n; i++) {
        const auto &idx = c->getIndex()[i];
        str += (i == 0 ? "" : ", ") + dispatch(idx);
        if (c->getObject()->getType() == NodeType::RangeOpNodeType) {
            auto rangeOp = as<RangeOpNode>(c->getObject());
            str += " - " +
                   std::to_string(rangeOp->getLoopVarRanges()[i].second.first -
                                  rangeOp->getPaddings(i));
        }
    }
    str += "]";
    return str;
}
std::string AsTVMVisitor::visit_(const Var &c) {
    std::string stageName = "s" + std::to_string(curStage);
    std::string varName = stageName + "_" + c->getName();
    if (offset.count(varName)) {
        return "(" + varName + " - " + std::to_string(offset.at(varName)) + ")";
    } else {
        return varName;
    }
}
std::string AsTVMVisitor::visit_(const Tensor &c) {
    pythonVars.emplace_back(c->getName());
    inputs.emplace_back(c->getName());
    inputShapes.emplace_back(c->getShape());
    std::string stmt = c->getName() + " = te.placeholder((";
    for (auto &&dim : c->getShape()) {
        stmt += std::to_string(dim) + ", ";
    }
    stmt += "), name='" + c->getName() + "')";
    stmts += stmt + "\n";
    return c->getName();
}
std::string AsTVMVisitor::getStmts() const {
    std::string ret;

    // Workaround because closure capturing does not work in an `exec`
    // https://stackoverflow.com/questions/2749655/why-are-closures-broken-within-exec
    ret += "global ";
    bool first = true;
    for (auto &&var : pythonVars) {
        ret += (first ? "" : ", ") + var;
        first = false;
    }
    ret += "\n";

    ret += stmts;
    ret += "ret = [" + output;
    for (auto &&input : inputs) {
        ret += ", " + input;
    }
    ret += "]\n";
    return ret;
}

} // namespace nnet
@@ -0,0 +1,35 @@
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void CheckOOBVisitor::visit_(const Subscript &c) {
|
||||
const auto &objectRanges = c->getObjectRangesWithPaddings();
|
||||
for (size_t dim = 0; dim < c->getDims(); ++dim) {
|
||||
SimplifyExprVisitor simplifier;
|
||||
auto optional = simplifier.getExprRange(c->getIndex(dim), rangeOp);
|
||||
if (!optional.has_value())
|
||||
continue;
|
||||
const Range &exprRange = *optional;
|
||||
if (exprRange.first < objectRanges[dim].first ||
|
||||
exprRange.second > objectRanges[dim].second) {
|
||||
// dbg("OOB detected!", c, dim, exprRange, objectRanges[dim]);
|
||||
// std::cout << "OOB detected! " << c->toReadable() << ", dim=" <<
|
||||
// dim
|
||||
// << ", Range=(" << exprRange.first << ", "
|
||||
// << exprRange.second << "), objRange=("
|
||||
// << objectRanges[dim].first << ", "
|
||||
// << objectRanges[dim].second << ")." << std::endl;
|
||||
detect = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
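// Usage sketch (names and ranges are illustrative): for a stage such as
//   L<i:0:8>Sum<k:0:4> { A[i + k] }  with A declared over [0, 8),
// the index i + k simplifies to the range [0, 11), which escapes A's padded
// range, so checkRangeOp() below returns true.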
bool CheckOOBVisitor::checkRangeOp(const RangeOp &_rangeOp) {
    detect = false;
    rangeOp = _rangeOp;
    dispatch(rangeOp);
    return detect;
}

} // namespace nnet
@@ -0,0 +1,9 @@
#include "nnet/Visitor/CloneMutator.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Expr CloneMutator::visit_(const Constant &c) { return c; }
|
||||
Expr CloneMutator::visit_(const Var &c) { return c; }
|
||||
Expr CloneMutator::visit_(const Tensor &c) { return c; }
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,34 @@
|
|||
#include "nnet/Visitor/CompareMultiFormulasVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
bool CompareMultiFormulasVisitor::compare(const VecExpr &roots) {
|
||||
if (roots.empty())
|
||||
return false;
|
||||
vector<RangeOp> rangeOps;
|
||||
for (const auto &root : roots) {
|
||||
if (auto rangeOp = as<RangeOpNode>(root))
|
||||
rangeOps.emplace_back(rangeOp);
|
||||
else
|
||||
return false;
|
||||
}
|
||||
const auto pattern = rangeOps[0];
|
||||
for (auto rangeOp : rangeOps) {
|
||||
if (pattern->getNumOutputDims() != rangeOp->getNumOutputDims()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < pattern->getNumOutputDims(); ++i)
|
||||
if (pattern->getVarRange(0, i).second !=
|
||||
rangeOp->getVarRange(0, i).second) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < pattern->getSumVarRanges().size(); ++i)
|
||||
if (pattern->getVarRange(1, i).second !=
|
||||
rangeOp->getVarRange(1, i).second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,38 @@
|
|||
#include "nnet/Visitor/CountRoutineVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void CountRoutineVisitor::visit_(const Tensor &c) {
|
||||
if (auto routine = c->getSource(); routine) {
|
||||
cnts[routineTypeToId(routine->getType())]++;
|
||||
}
|
||||
ExprTreeVisitor::visit_(c);
|
||||
}
|
||||
|
||||
vector<int> CountRoutineVisitor::count(const Expr &root) {
|
||||
cnts = vector<int>(RoutineTypeCnt, 0);
|
||||
dispatch(root);
|
||||
return cnts;
|
||||
}
|
||||
|
||||
bool CountRoutineVisitor::match(const Expr &root, int nMatmul, int nConv,
|
||||
int nElement, int nSg2bmm,
|
||||
int nLongformerGBMM) {
|
||||
auto opCount = count(root);
|
||||
bool ret = true;
|
||||
if (opCount[routineTypeToId(RoutineType::MatmulNodeType)] != nMatmul)
|
||||
ret = false;
|
||||
if (opCount[routineTypeToId(RoutineType::ConvNodeType)] != nConv)
|
||||
ret = false;
|
||||
if (opCount[routineTypeToId(RoutineType::ElementWiseNodeType)] != nElement)
|
||||
ret = false;
|
||||
if (opCount.at(routineTypeToId(RoutineType::G2bmmNodeType)) != nSg2bmm)
|
||||
ret = false;
|
||||
if (!ret) {
|
||||
auto target =
|
||||
vector<int>{nMatmul, nConv, nSg2bmm, nLongformerGBMM, nElement};
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,58 @@
|
|||
#include "nnet/Visitor/FullPrinterVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void FullPrinterVisitor::visit_(const Tensor &c) {
|
||||
q.emplace_back(c->getName(), c->getSource(), c);
|
||||
}
|
||||
|
||||
string FullPrinterVisitor::print(const Expr &root) {
|
||||
q.clear();
|
||||
std::ostringstream oss;
|
||||
dispatch(root);
|
||||
oss << "==> ROOT\n" << root->toReadable() << "\n";
|
||||
for (size_t i = 0; i < q.size(); ++i) {
|
||||
const auto &[name, routine, tensor] = q[i];
|
||||
oss << "==> " << name << " : ";
|
||||
if (routine) {
|
||||
oss << routine->toReadable() << "\n";
|
||||
if (routine->getExpr()) {
|
||||
oss << routine->getExpr()->toReadable() << "\n";
|
||||
} else
|
||||
oss << "[INFO] Source is nullptr \n";
|
||||
if (!routine->getInputs().empty()) {
|
||||
for (const auto &tensor : routine->getInputs())
|
||||
q.emplace_back(tensor->getName(), tensor->getSource(),
|
||||
tensor);
|
||||
} else if (routine->getExpr())
|
||||
dispatch(routine->getExpr());
|
||||
} else
|
||||
oss << "Input Tensor " << tensor->toOutputShape() << "\n";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
const vector<tuple<string, Routine, Tensor>> &
|
||||
FullPrinterVisitor::traverse(const Expr &root) {
|
||||
q.clear();
|
||||
dispatch(root);
|
||||
for (size_t i = 0; i < q.size(); ++i) {
|
||||
const auto &[name, routine, tensor] = q[i];
|
||||
if (routine) {
|
||||
// Matmul after DLT do not modify expression, so inputs has a higher
|
||||
// priority. Some OPs such as DLT have not implement source. Then
|
||||
// use inputs
|
||||
if (!routine->getInputs().empty()) {
|
||||
for (const auto &tensor : routine->getInputs())
|
||||
dispatch(tensor);
|
||||
} else if (routine->getExpr()) {
|
||||
dispatch(routine->getExpr());
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
return q;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,9 @@
|
|||
#include "nnet/Visitor/GetTensorsVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void GetTensorsVisitor::visit_(const Tensor &c) {
|
||||
tensors.try_emplace(c->getName(), c);
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,7 @@
|
|||
#include "nnet/Visitor/InputVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void InputVisitor::visit_(const Tensor &c) { inputs.emplace_back(c); }
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,174 @@
|
|||
#include "nnet/Visitor/Interpreter.h"
|
||||
#include "nnet/Visitor/GetTensorsVisitor.h"
|
||||
#include "nnet/expr.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
using ttype = Interpreter::ttype; // Test data type
|
||||
using rtype = Interpreter::rtype; // Return data type
|
||||
using Position = Interpreter::Position;
|
||||
using Inputs = Interpreter::Inputs;
|
||||
using Iteration = Interpreter::Iteration;
|
||||
|
||||
Inputs Interpreter::genInputStartingFromZero(const RangeOp &range) {
|
||||
Inputs inputs;
|
||||
GetTensorsVisitor getTensorsVisitor;
|
||||
auto tensors = getTensorsVisitor.get(range);
|
||||
|
||||
for (const auto &[name, tensor] : tensors) {
|
||||
auto data = make_ref<vector<int>>(tensor->getSize());
|
||||
for (ssize_t i = 0; i < tensor->getSize(); i++) {
|
||||
data->at(i) = i;
|
||||
}
|
||||
inputs.emplace(name, data);
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
Interpreter::Interpreter(RangeOp range, int _verbose)
|
||||
: Interpreter(genInputStartingFromZero(range), _verbose){};
|
||||
|
||||
rtype Interpreter::visit_(const Constant &c) { return c->getValue(); }
|
||||
|
||||
rtype Interpreter::visit_(const BinaryOp &c) {
|
||||
rtype valueL = dispatch(c->getLhs()), valueR = dispatch(c->getRhs());
|
||||
|
||||
switch (c->getOpType()) {
|
||||
case OpType::Add:
|
||||
return valueL + valueR;
|
||||
case OpType::Mul:
|
||||
return valueL * valueR;
|
||||
case OpType::Div:
|
||||
nnet_assert(valueR > 0, "Negative divisor is ill-defeind");
|
||||
return valueL / valueR;
|
||||
case OpType::Mod:
|
||||
nnet_assert(valueR > 0, "Negative divisor is ill-defeind");
|
||||
return valueL % valueR;
|
||||
case OpType::Sub:
|
||||
return valueL - valueR;
|
||||
default:
|
||||
nnet_unimplemented_halt();
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
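// Evaluating a RangeOp: the current output position (positions.back()) binds
// the loop variables, then sum variables are enumerated like an odometer.
// E.g. for Sum<k1:0:2><k2:0:3>, the summand is accumulated at (k1, k2) =
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2). Positions that fall only in a
// loop variable's padding evaluate to 0.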
rtype Interpreter::visit_(const RangeOp &c) {
    rtype ret = 0;
    iterations.emplace_back();
    // loop
    auto loopRanges = c->getLoopVarRanges();
    assert(positions.back().size() == loopRanges.size());
    auto paddings = c->getPaddings();
    for (int i = 0, iEnd = loopRanges.size(); i < iEnd; i++) {
        int left = loopRanges[i].second.first;
        int right = loopRanges[i].second.second;
        int padding = paddings[i];
        int element = positions.back()[i];
        if (0 < padding) {
            nnet_assert(left - padding <= element, "Out of range");
            nnet_assert(element < right + padding, "Out of range");
            if (left <= element && element < right) {
                iterations.back()[loopRanges[i].first] = positions.back()[i];
            } else {
                iterations.pop_back();
                return 0;
            }
        } else {
            nnet_assert(left <= element, "Out of range");
            nnet_assert(element < right, "Out of range");
            iterations.back()[loopRanges[i].first] = positions.back()[i];
        }
    }
    // sum
    auto sumVarRanges = c->getSumVarRanges();
    int nSumIters = sumVarRanges.size();
    if (0 < nSumIters) {
        vector<int> sumIterValues;
        sumIterValues.reserve(nSumIters);
        for (const auto &[var, range] : sumVarRanges) {
            sumIterValues.emplace_back(range.first);
            nnet_assert(range.first < range.second, "No empty range");
        }
        // Enumerate all values of the sum iterators
        do {
            for (int i = 0; i < nSumIters; i++)
                iterations.back()[sumVarRanges[i].first] = sumIterValues[i];
            ret += dispatch(c->getSummand());

            // Increase with carry to enumerate sum iterators
            sumIterValues[nSumIters - 1]++;
            for (int i = nSumIters - 1; 0 < i; i--) {
                if (sumIterValues[i] == sumVarRanges[i].second.second) {
                    sumIterValues[i] = sumVarRanges[i].second.first;
                    sumIterValues[i - 1]++;
                }
            }
        } while (sumIterValues[0] < sumVarRanges[0].second.second);
    } else {
        ret += dispatch(c->getSummand());
    }
    iterations.pop_back();
    return ret;
}

rtype Interpreter::visit_(const Subscript &c) {
    int ret = 0;
    vector<int> idx;
    auto sub = c->getIndex();
    for (int i = 0, iEnd = sub.size(); i < iEnd; i++) {
        idx.emplace_back(dispatch(sub[i]));
    }

    auto obj = c->getObject();
    if (obj->getType() == NodeType::RangeOpNodeType) {
        positions.emplace_back(idx);
        ret = dispatch(obj);
        positions.pop_back();
    } else if (obj->getType() == NodeType::TensorNodeType) {
        auto tensor = as<TensorNode>(obj);
        const auto &data = inputs[tensor->getName()];
        ret = tensor->getData(data, idx);
    } else
        assert(false);
    return ret;
}

rtype Interpreter::visit_(const Var &c) { return iterations.back()[c]; }

rtype Interpreter::visit_(const Tensor &c) {
    nnet_unimplemented_halt();
    return -1;
}

vector<rtype> Interpreter::interpret(const Expr &expr,
                                     const vector<Position> &poses) {
    vector<rtype> ret;
    for (const auto &pos : poses) {
        positions.emplace_back(pos);
        ret.emplace_back(dispatch(expr));
        positions.pop_back();
    }
    return ret;
}

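// Uniform sampling flattens the output and decodes every
// (outputSize / nPoses)-th linear offset back into coordinates. E.g. for a
// 4x6 output, offset t = 13 decodes to (13 / 6, 13 % 6) = (2, 1).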
vector<rtype> Interpreter::interpretUniformSample(const RangeOp &range,
                                                  int nPoses) {
    vector<Interpreter::Position> poses;
    const auto &rangeShape = range->getOutputShape();
    for (int i = 0; i < nPoses; ++i) {
        Interpreter::Position pos(range->getNumOutputDims(), 0);
        ssize_t t = range->getOutputSize() / nPoses * i;
        for (int j = range->getNumOutputDims() - 1; j >= 0; --j) {
            int extent = rangeShape[j];
            pos[j] = t % extent;
            t /= extent;
        }
        poses.emplace_back(pos);
    }
    return interpret(range, poses);
}

vector<rtype> Interpreter::interpretAllOutput(const RangeOp &range) {
    return interpretUniformSample(range, range->getOutputSize());
}

} // namespace nnet
@@ -0,0 +1,46 @@
#include "nnet/Visitor/MatchReshapeVisitor.h"
|
||||
#include "nnet/Visitor/MergeMemboundMutator.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
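// A stage is a pure reshape iff its summand is a single subscript whose
// flattened (row-major) index assigns every loop variable the same stride it
// has in the stage's own output layout. Illustrative example:
//   L<i:0:4><j:0:6> { A[(6 * i + j) / 8, (6 * i + j) % 8] }
// flattens to 6 * i + j, matching the output strides (6, 1), so the stage
// only reinterprets A's 24 elements.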
bool MatchReshapeVisitor::visit_(const RangeOp &memboundRangeOp) {
    // Merge nested stages
    auto rangeOp =
        as<RangeOpNode>(MergeMemboundMutator({memboundRangeOp}).merge());
    assert(rangeOp);
    auto sub = as<SubscriptNode>(rangeOp->getSummand());
    if (!sub)
        return false;
    auto sumRanges = rangeOp->getSumVarRanges();
    for (auto const &[var, range] : sumRanges) {
        if (range.second - range.first != 1)
            return false;
    }

    const auto objectRanges = sub->getObjectRangesWithoutPaddings();
    const auto indices = sub->getIndex();
    Expr indexExpr;
    int stride = 1;
    for (int i = indices.size() - 1; i >= 0; --i) {
        indexExpr = indexExpr + indices.at(i) * stride;
        stride *= (objectRanges.at(i).second - objectRanges.at(i).first);
    }

    SimplifyExprVisitor simplifyExprVisitor;
    simplifyExprVisitor.simplify(indexExpr);
    auto exprStrides = simplifyExprVisitor.getStrides();

    auto varRanges = rangeOp->getLoopVarRanges();
    stride = 1;
    // Compare the strides of variables in the RangeOp and in the index
    for (auto i = varRanges.rbegin(); i != varRanges.rend(); ++i) {
        const bool alwaysZero = i->second.first == 0 && i->second.second == 1;
        if (!alwaysZero && exprStrides[i->first] != stride)
            return false;
        stride *= (i->second.second - i->second.first);
    }
    return true;
}

} // namespace nnet
@@ -0,0 +1,56 @@
#include "nnet/Visitor/MatchTableVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
void MatchTableVisitor::visit_(const BinaryOp &c, const Tensor &tensor, int dim,
|
||||
optional<int> stride) {
|
||||
if (c->getOpType() == OpType::Add) {
|
||||
dispatch(c->getLhs(), tensor, dim, stride);
|
||||
dispatch(c->getRhs(), tensor, dim, stride);
|
||||
} else if (c->getOpType() == OpType::Sub) {
|
||||
dispatch(c->getLhs(), tensor, dim, stride);
|
||||
if (stride)
|
||||
*stride = -*stride;
|
||||
dispatch(c->getRhs(), tensor, dim, stride);
|
||||
} else if (c->getOpType() == OpType::Mul) {
|
||||
const optional<int> &lStride = subexprStride[c->getLhs().get()];
|
||||
const optional<int> &rStride = subexprStride[c->getRhs().get()];
|
||||
optional<int> lCurStride =
|
||||
(stride && rStride) ? optional(*stride * *rStride) : nullopt;
|
||||
optional<int> rCurStride =
|
||||
(stride && lStride) ? optional(*stride * *lStride) : nullopt;
|
||||
dispatch(c->getLhs(), tensor, dim, lCurStride);
|
||||
dispatch(c->getRhs(), tensor, dim, rCurStride);
|
||||
} else {
|
||||
hasUnsupportedOp = true;
|
||||
}
|
||||
}
|
||||
|
||||
void MatchTableVisitor::visit_(const Subscript &c, const Tensor &tensor,
|
||||
[[maybe_unused]] int dim,
|
||||
[[maybe_unused]] optional<int> stride) {
|
||||
assert(!tensor); // Should not be set until visit a tensor
|
||||
auto object = as<TensorNode>(c->getObject());
|
||||
assert(object);
|
||||
tensors.emplace_back(object);
|
||||
int currentStride = 1;
|
||||
for (int i = (int)c->getDims() - 1; i >= 0; --i) {
|
||||
this->dispatch(c->getIndex(i), object, i, currentStride);
|
||||
currentStride *= object->getShape(i);
|
||||
}
|
||||
}
|
||||
void MatchTableVisitor::visit_(const Var &c, const Tensor &tensor, int dim,
|
||||
optional<int> stride) {
|
||||
appearance.try_emplace(c);
|
||||
appearance[c].emplace_back(pair(tensor, dim));
|
||||
strideTable[c].emplace_back(tensor.get(), dim, stride.value());
|
||||
}
|
||||
|
||||
void MatchTableVisitor::visit_([[maybe_unused]] const Constant &c,
|
||||
[[maybe_unused]] const Tensor &tensor,
|
||||
[[maybe_unused]] int dim,
|
||||
[[maybe_unused]] optional<int> stride) {
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,100 @@
|
|||
#include "nnet/Visitor/MatmulTransposeMutator.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
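// Enumerates the 7 non-trivial gemm variants of C = A x B via a 3-bit mask:
// bit 0 transposes A (inserting an explicit input transpose), bit 1 does the
// same for B, and bit 2 swaps the operands using C^T = B^T x A^T, which is
// why the swapped cases are wrapped in an output-permuting RangeOp below.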
VecExpr MatmulTransposeMutator::transpose(const Tensor &tensor) {
    assert(tensor->getDims() == 2);
    const auto matmul = as<MatmulNode>(tensor->getSource());
    VecExpr ret;
    for (int i = 1; i < 8; ++i) {
        // Whether to transpose/swap A and B; transa/b are the gemm arguments
        const int Atrans = (i & 1) > 0;
        const int Btrans = (i & 2) > 0;
        const int ABswap = (i & 4) > 0;

        auto newShape = tensor->getShape();
        auto newPaddings = tensor->getPaddings();
        auto [b, m, n, k, transa, transb] = matmul->getArgs();
        auto inputs = matmul->getInputs();
        transa ^= Atrans;
        transb ^= Btrans;
        // build input transpose
        if (Atrans)
            inputs[0] = transposeInput(inputs[0]);
        if (Btrans)
            inputs[1] = transposeInput(inputs[1]);
        if (ABswap) {
            std::swap(inputs[0], inputs[1]);
            std::swap(m, n);
            std::swap(transa, transb);
            std::swap(newShape[0], newShape[1]);
            std::swap(newPaddings[0], newPaddings[1]);
            transa ^= 1;
            transb ^= 1;
        }
        // build new Gemm Routine and Tensor
        // HACK: trivially wrap the source to generate a different hash
        auto _va = make_ref<VarNode>("transA");
        auto _vb = make_ref<VarNode>("transB");
        auto _vc = make_ref<VarNode>("swapAB");
        auto fakeSub = makeSubscript(matmul->getExpr(), {_va, _vb});
        auto fakeRangeWrapperForHackHash =
            makeRangeOperator({{_va, {0, Atrans + 100}},
                               {_vb, {0, Btrans + 100}},
                               {_vc, {0, ABswap + 100}}},
                              {}, fakeSub);
        Matmul newMatmul =
            make_ref<MatmulNode>(fakeRangeWrapperForHackHash, inputs[0],
                                 inputs[1], b, m, n, k, transa, transb);
        auto newTensor = makeTensor(derivator.newTensorName(), newShape,
                                    newPaddings, newMatmul);
        // build output transpose
        if (ABswap) {
            vector<Var> vars{derivator.getNewVar(), derivator.getNewVar()};
            auto sub = makeSubscript(newTensor, {vars[1], vars[0]});
            vector<VarRangePair> loopVRs;
            // Since the inputs array may be swapped, use the original tensor
            // shape
            for (int i = 0; i < 2; ++i) {
                loopVRs.emplace_back(vars[i], Range(0, tensor->getShape(i)));
            }
            auto rangeOp = makeRangeOperator(loopVRs, {}, sub);
            ret.emplace_back(rangeOp);
        } else
            ret.emplace_back(newTensor);
    }
    return ret;
}

Tensor MatmulTransposeMutator::transposeInput(const Tensor &tensor) {
    Tensor ret;
    if (auto ew = as<ElementWiseNode>(tensor->getSource())) {
        auto rangeOp = as<RangeOpNode>(tensor->getSource()->getExpr());
        assert(rangeOp);
        assert(rangeOp->getNumOutputDims() == 2);
        auto loopVRs = rangeOp->getLoopVarRanges();
        std::swap(loopVRs[0], loopVRs[1]);
        // If there are paddings, the inner stage paddings should be removed
        assert(!rangeOp->hasPaddings());
        // auto paddings = rangeOp->getPaddings();
        // std::swap(paddings[0], paddings[1]);
        auto sub = makeSubscript(rangeOp, {loopVRs[1].first, loopVRs[0].first});
        auto newRangeOp = makeRangeOperator(loopVRs, {}, sub);
        // ElementWise newElementWise = make_ref<ElementWiseNode>(*ew);
        auto outputShape = ew->getOutputShape();
        std::swap(outputShape[0], outputShape[1]);
        auto newElementWise =
            make_ref<ElementWiseNode>(newRangeOp, ew->getInputs(), outputShape);

        auto tensorShape = tensor->getShape();
        auto tensorPaddings = tensor->getPaddings();
        std::swap(tensorShape[0], tensorShape[1]);
        std::swap(tensorPaddings[0], tensorPaddings[1]);
        ret = makeTensor(derivator.newTensorName(), tensorShape, tensorPaddings,
                         newElementWise);
        // } else if (!tensor->getSource()) {
    } else
        nnet_unimplemented_halt();
    return ret;
}

} // namespace nnet
@@ -0,0 +1,135 @@
#include "nnet/Visitor/MergeMemboundMutator.h"
|
||||
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||
#include "nnet/Visitor/ReplaceVariable.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
|
||||
// FIXME: fix empty expression in membound
|
||||
assert(kernels.size() >= 1);
|
||||
if (checkEmpty()) {
|
||||
if (allowEmptyMembound)
|
||||
return nullptr;
|
||||
else
|
||||
nnet_assert(false, "Empty membound expression");
|
||||
}
|
||||
// Nesting stages
|
||||
auto expr = dispatch(kernels.back());
|
||||
// Fusing stages
|
||||
bool merged = false;
|
||||
do {
|
||||
merged = false;
|
||||
RangeOp curRangeOp;
|
||||
for (Expr *curExpr = &expr;
|
||||
curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
|
||||
auto curRangeOp = as<RangeOpNode>(*curExpr);
|
||||
assert(CheckOOBVisitor().checkRangeOp(curRangeOp) == false);
|
||||
auto summand = curRangeOp->getSummand();
|
||||
if (auto subscriptOp = as<SubscriptNode>(summand)) {
|
||||
if (auto mergedExpr = rule4StageMerging(*curExpr, true)) {
|
||||
// dbg(*curExpr, mergedExpr);
|
||||
*curExpr = mergedExpr;
|
||||
merged = true;
|
||||
break;
|
||||
}
|
||||
curExpr = subscriptOp->getObjectPtr();
|
||||
nnet_assert(*curExpr != nullptr, __LINE__);
|
||||
} else if (auto funcOp = as<FuncNode>(summand)) {
|
||||
// Relu({...}[i,j])
|
||||
curExpr = funcOp->getObject()->getObjectPtr();
|
||||
} else
|
||||
nnet_unimplemented_halt();
|
||||
}
|
||||
} while (merged);
|
||||
return expr;
|
||||
}
|
||||
|
||||
bool MergeMemboundMutator::checkEmpty() {
|
||||
for (const auto &k : kernels) {
|
||||
if (k == nullptr)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Expr MergeMemboundMutator::visit_(const Tensor &c) {
|
||||
if (curDepth > 0)
|
||||
return dispatch(kernels[--curDepth]);
|
||||
else {
|
||||
// Reach the last tensor, return it to reconstruct the total tree
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
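// Rule-4 stage merging, sketched: when an outer stage subscripts an inner
// RangeOp, e.g.
//   L<i,j> { (L<p,q> { X[p, q] })[f(i,j), g(i,j)] },
// substituting p := f(i,j) and q := g(i,j) into the inner summand collapses
// the pair to L<i,j> { X[f(i,j), g(i,j)] }; inner sum iterators are hoisted
// and, where a loop var is used directly as an index, paddings move onto the
// tensor.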
Expr MergeMemboundMutator::rule4StageMerging(Expr &rCur,
                                             bool mergeStageWithCalc) {
    auto rangeOp0 = as<RangeOpNode>(rCur);
    const Subscript &sub0 = as<SubscriptNode>(rangeOp0->getSummand());
    if (!sub0)
        return nullptr;
    const auto &rangeOp1 = as<RangeOpNode>(sub0->getObject());
    if (!rangeOp1)
        return nullptr;
    const auto &sub1 = as<SubscriptNode>(rangeOp1->getSummand());
    if (!sub1)
        return nullptr;
    // Merge a stage with calculation only when mergeStageWithCalc is set
    if (!mergeStageWithCalc && !rangeOp1->getSumVarRanges().empty())
        return nullptr;
    // Only propagate paddings in perfectly nested dimensions
    if (rangeOp1->hasPaddings()) {
        auto oldTensor = as<TensorNode>(sub1->getObject());
        if (!oldTensor) {
            nnet_unimplemented_continue();
            return nullptr;
        }
    }
    // Replace variables: iters of rangeOp1 are replaced by indexes of sub0
    map<string, pair<Expr, Expr>> varMapping;
    assert(sub0->getDims() == rangeOp1->getLoopVarRanges().size());
    for (size_t i = 0; i < sub0->getDims(); ++i) {
        varMapping[rangeOp1->getLoopVar(i)->getName()] =
            pair(rangeOp1->getLoopVar(i), sub0->getIndex(i));
    }
    ReplaceVariable replaceVariable{varMapping};
    auto merged = make_ref<RangeOpNode>(*rangeOp0);
    merged->setSummand(replaceVariable(sub1));
    // A naive approach to propagate paddings
    if (rangeOp1->hasPaddings()) {
        auto oldTensor = as<TensorNode>(sub1->getObject());
        auto newTensor = make_ref<TensorNode>(*oldTensor);
        for (int i = 0; i < rangeOp1->getNumOutputDims(); ++i) {
            if (rangeOp1->getPaddings(i) == 0)
                continue;
            auto loopVar = rangeOp1->getLoopVar(i);
            // FIXME: in fact this var should not appear in other indexes as
            // well, which may result in OOB
            bool findSingleVarAsIndex = false;
            for (size_t subIndexID = 0; subIndexID < sub1->getDims();
                 ++subIndexID) {
                auto index = sub1->getIndex(subIndexID);
                if (auto indexVar = as<VarNode>(index);
                    indexVar && (indexVar->equal(loopVar))) {
                    newTensor->setPadding(subIndexID,
                                          newTensor->getPadding(subIndexID) +
                                              rangeOp1->getPaddings(i));
                    findSingleVarAsIndex = true;
                }
            }
            if (!findSingleVarAsIndex) {
                nnet_unimplemented_continue();
                return nullptr;
            }
        }
        merged = as<RangeOpNode>(
            ReplaceNodeMutator().replace(merged, oldTensor.get(), newTensor));
        assert(merged != nullptr);
    }
    // Merge inner stage sums
    if (!rangeOp1->getSumVarRanges().empty())
        merged->setSumIterator(rangeOp1->getSumVarRanges());
    return merged;
}

} // namespace nnet
@@ -0,0 +1,124 @@
#include "nnet/Visitor/PatternMatcher.h"
|
||||
#include "nnet/ReplaceKit.h"
|
||||
#include "nnet/Visitor/MatmulTransposeMutator.h"
|
||||
#include "nnet/Visitor/ReplaceVariable.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
PatternMatcher::PatternMatcher(Derivator &derivator, const RangeOp &cur)
|
||||
: Functor(false), derivator(derivator), hasNonZeroRange(false),
|
||||
originalCur(cur) {
|
||||
for (const auto &[var, range] : cur->getLoopVarRanges())
|
||||
if (range.first != 0)
|
||||
hasNonZeroRange = true;
|
||||
}
|
||||
|
||||
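// Patterns assume loop ranges starting at 0. For a loop var i with range
// [2, 10), getOffsetCur() below introduces i' in [0, 8) and substitutes
// i := i' + 2 into the summand; applyWrapper() later shifts matched kernels
// back into the original coordinates.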
RangeOp PatternMatcher::getOffsetCur() {
    if (!hasNonZeroRange)
        return originalCur;
    vector<Var> itersFromNonZero;
    VecExpr psis;
    vector<VarRangePair> newLoopVarRanges;
    for (const auto &[var, range] : originalCur->getLoopVarRanges()) {
        if (range.first == 0) {
            newLoopVarRanges.emplace_back(var, range);
        } else {
            auto newVar = derivator.getNewVar();
            newLoopVarRanges.emplace_back(newVar,
                                          pair(0, range.second - range.first));
            itersFromNonZero.emplace_back(var);
            psis.emplace_back(newVar + range.first);
        }
    }
    auto newSummand = ReplaceKit::replaceMultipleExprs(
        originalCur->getSummand(), itersFromNonZero, psis);
    return makeRangeOperator(newLoopVarRanges, originalCur->getSumVarRanges(),
                             newSummand);
}

VecExpr PatternMatcher::matchKernel(const Pattern &pattern,
                                    const RangeOp &rangeOp,
                                    IteratorTable &exprIT) {
    VecExpr ret;
    if (pattern.getNumTensors() != (int)exprIT.getNumTensors())
        return ret;

    // Whether to enable tensor permutation
    if (false) {
        const int nInputs = pattern.getNumInputs();
        vector<int> tensorMap; // [tensor index] -> pattern tensor ID
        for (int i = 0; i < nInputs; ++i)
            tensorMap.emplace_back(i);
        do {
            exprIT.buildTable(tensorMap);
            auto matched = matchKernelWithTensorMap(pattern, rangeOp, exprIT);
            if (matched)
                ret.emplace_back(matched);
        } while (std::next_permutation(tensorMap.begin(), tensorMap.end()));
    } else {
        exprIT.buildTableWithDefaultMap();
        auto matched = matchKernelWithTensorMap(pattern, rangeOp, exprIT);
        if (matched)
            ret.emplace_back(matched);
    }
    // Generate 8 variants of gemm
    if (true) // set to false to disable while debugging
        if (!ret.empty() && dynamic_cast<const MatmulPattern *>(&pattern)) {
            auto tensor = as<TensorNode>(ret[0]);
            auto transposeds =
                MatmulTransposeMutator(derivator).transpose(tensor);
            for (const auto &transposed : transposeds)
                ret.emplace_back(transposed);
        }
    return ret;
}

Expr PatternMatcher::matchKernelWithTensorMap(const Pattern &pattern,
                                              const RangeOp &rangeOp,
                                              IteratorTable &exprIT) {
    auto mismatches = exprIT.matchPatternIT(pattern);
    if (!mismatches.empty())
        return nullptr;

    const auto &[tensorMap_r, iterToRange_r] = exprIT.getReverseMap();
    // // TODO: check OOB error
    // for (int tensorID = 0; tensorID < pattern.getNumInputs(); ++tensorID) {
    //     if (!checkIndexOutOfBound(pattern.getIterInTensorDim(tensorID),
    //                               tensorMap_r[tensorID], iterToRange_r))
    //         return nullptr;
    // }

    // Matched! Build the expression for ret
    return pattern.buildExpr(rangeOp, tensorMap_r, iterToRange_r,
                             derivator.newTensorName(), exprIT);
}

VecExpr PatternMatcher::applyWrapper(const VecExpr &exprs) {
    if (!hasNonZeroRange)
        return exprs;
    VecExpr ret, indexes;
    for (const auto &[var, range] : originalCur->getLoopVarRanges()) {
        if (range.first == 0) {
            indexes.emplace_back(var);
        } else {
            hasNonZeroRange = true;
            indexes.emplace_back(var - range.first);
        }
    }
    for (auto &expr : exprs) {
        auto newSub = makeSubscript(expr, indexes);
        ret.emplace_back(makeRangeOperator(originalCur->getLoopVarRanges(), {},
                                           newSub, originalCur->getPaddings()));
    }
    return ret;
}

VecExpr PatternMatcher::matchWithPattern(const RangeOp &rangeOp,
                                         const Pattern &pattern) {
    IteratorTable exprIT;
    if (!exprIT.analyzeExpr(rangeOp))
        return {};
    return matchKernel(pattern, rangeOp, exprIT);
}

} // namespace nnet
@@ -0,0 +1,56 @@
#include "nnet/Visitor/RangeMagnifyVisitor.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Expr RangeMagnifyVisitor::visit_(const RangeOp &c) {
|
||||
if (auto expr = Mutator::visit_(c)) {
|
||||
auto ret = as<RangeOpNode>(expr);
|
||||
ret->setSumIterator(newSumVarRanges);
|
||||
return ret;
|
||||
} else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Expr RangeMagnifyVisitor::visit_(const Subscript &c) {
|
||||
auto tensor = as<TensorNode>(c->getObject());
|
||||
if (!tensor)
|
||||
return nullptr;
|
||||
// Check new ranges
|
||||
bool paddingMagnify = false;
|
||||
vector<Range> tensorRanges = c->getObjectRangesWithPaddings();
|
||||
vector<int> paddingsDelta(tensorRanges.size(), 0);
|
||||
for (int i = 0; i < (int)c->getDims(); ++i) {
|
||||
auto indexRange =
|
||||
SimplifyExprVisitor().getExprRange(c->getIndex(i), newRangeOp);
|
||||
if (!indexRange.has_value())
|
||||
return nullptr;
|
||||
int delta = max(tensorRanges[i].first - indexRange->first,
|
||||
indexRange->second - tensorRanges[i].second);
|
||||
if (delta > 0) {
|
||||
paddingMagnify = true;
|
||||
paddingsDelta[i] = delta;
|
||||
}
|
||||
}
|
||||
if (!paddingMagnify)
|
||||
return nullptr;
|
||||
// Create new tensor. Direct add paddings to the Tensor.
|
||||
auto newTensor = make_ref<TensorNode>(*tensor);
|
||||
for (int i = 0; i < newTensor->getDims(); ++i)
|
||||
newTensor->setPadding(i, newTensor->getPadding(i) + paddingsDelta[i]);
|
||||
auto newSub = make_ref<SubscriptNode>(*c);
|
||||
newSub->setObject(newTensor);
|
||||
return newSub;
|
||||
}
|
||||
|
||||
RangeOp
|
||||
RangeMagnifyVisitor::magnify(const RangeOp &root,
|
||||
const vector<VarRangePair> &_newSumVarRanges) {
|
||||
newSumVarRanges = _newSumVarRanges;
|
||||
newRangeOp = make_ref<RangeOpNode>(*root);
|
||||
newRangeOp->setSumIterator(newSumVarRanges);
|
||||
const auto &newCur = as<RangeOpNode>(dispatch(root));
|
||||
return newCur;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,133 @@
|
|||
#include "nnet/Visitor/RangeRelaxFunctor.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
RangeMap RangeRelaxFunctor::visit_(const BinaryOp &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
if (c->getOpType() == OpType::Mul)
|
||||
return intersectRangeMaps(dispatch(c->getLhs()), dispatch(c->getRhs()));
|
||||
nnet_unimplemented_halt();
|
||||
return RangeMap();
|
||||
}
|
||||
RangeMap RangeRelaxFunctor::visit_(const RangeOp &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
return dispatch(c->getSummand());
|
||||
}
|
||||
RangeMap RangeRelaxFunctor::visit_(const Subscript &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
RangeMap ret;
|
||||
const auto &tensor = as<TensorNode>(c->getObject());
|
||||
for (size_t i = 0; i < c->getDims(); ++i) {
|
||||
const int padding = tensor->getPadding(i);
|
||||
const int length = tensor->getShape(i);
|
||||
if (!padding)
|
||||
continue;
|
||||
// TODO assert access within padding+length
|
||||
// If the index is a single Var
|
||||
if (auto var = as<VarNode>(c->getIndex(i))) {
|
||||
ret = intersectRangeMaps(ret, {{var, {0, length}}});
|
||||
} else { // If the index is linear expression
|
||||
const auto &[strides, offset] =
|
||||
SimplifyExprVisitor().getStridesConstant(c->getIndex(i));
|
||||
// // Calculate the sum of ranges of all iters in negative value
|
||||
// Range allRange(-offset, -offset);
|
||||
// for (const auto &[iter, stride] : strides) {
|
||||
// auto iterRange = rangeOp->getVarRange(iter).second;
|
||||
// if (stride > 0) {
|
||||
// allRange.first -= stride * (iterRange.second - 1);
|
||||
// allRange.second -= stride * iterRange.first;
|
||||
// } else {
|
||||
// allRange.first += stride * iterRange.first;
|
||||
// allRange.second += stride * (iterRange.second - 1);
|
||||
// }
|
||||
// dbg(iter, stride, iterRange, allRange);
|
||||
// }
|
||||
// dbg(allRange);
|
||||
// // Calculate the meaningful ranges for each iter
|
||||
// for (const auto &[iter, stride] : strides) {
|
||||
// auto iterRange = rangeOp->getVarRange(iter).second;
|
||||
// auto rangeExceptThis{allRange};
|
||||
// if (stride > 0) {
|
||||
// rangeExceptThis.first += stride * (iterRange.second - 1);
|
||||
// rangeExceptThis.second += stride * iterRange.first;
|
||||
// } else {
|
||||
// rangeExceptThis.first -= stride * iterRange.first;
|
||||
// rangeExceptThis.second -= stride * (iterRange.second -
|
||||
// 1);
|
||||
// }
|
||||
// // Meaningful calculation range for current iter
|
||||
// int l, r;
|
||||
// if (stride > 0) {
|
||||
// // l = (0 - rangeExceptThis.second + stride - 1) /
|
||||
// stride;
|
||||
// // r = (length - rangeExceptThis.first) / stride;
|
||||
// l = (0 - rangeExceptThis.second + stride - 1) / stride;
|
||||
// r = (length - 1 - rangeExceptThis.first) / stride + 1;
|
||||
// } else {
|
||||
// nnet_unimplemented_continue();
|
||||
// continue;
|
||||
// }
|
||||
// dbg(iter, stride, iterRange, l, r);
|
||||
// ret = intersectRangeMaps(ret, {{iter, {l, r}}});
|
||||
// }
|
||||
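            // For a linear index e = sum_k(stride_k * i_k) + offset over a
            // dimension of extent `length`, the useful range [l, r) of one
            // iterator with stride s > 0 keeps 0 <= e < length reachable
            // given the extremes of the other terms: l = ceil(-maxOthers / s)
            // and r = floor((length - 1 - minOthers) / s) + 1, computed below
            // via rangeExceptThis.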
            // Range of the full index expression over all iterators
            Range allRange(offset, offset);
            for (const auto &[iter, stride] : strides) {
                auto iterRange = rangeOp->getVarRange(iter).second;
                if (stride > 0) {
                    allRange.first += stride * iterRange.first;
                    allRange.second += stride * (iterRange.second - 1);
                } else {
                    allRange.first += stride * (iterRange.second - 1);
                    allRange.second += stride * iterRange.first;
                }
                // dbg(iter, stride, iterRange, allRange);
            }
            // Calculate the meaningful ranges for each iter
            for (const auto &[iter, stride] : strides) {
                auto iterRange = rangeOp->getVarRange(iter).second;
                auto rangeExceptThis{allRange};
                if (stride > 0) {
                    rangeExceptThis.first -= stride * iterRange.first;
                    rangeExceptThis.second -= stride * (iterRange.second - 1);
                } else {
                    rangeExceptThis.first -= stride * (iterRange.second - 1);
                    rangeExceptThis.second -= stride * iterRange.first;
                }
                // Meaningful calculation range for the current iter
                int l, r;
                if (stride > 0) {
                    // l = (0 - rangeExceptThis.second + stride - 1) / stride;
                    // r = (length - rangeExceptThis.first) / stride;
                    l = (0 - rangeExceptThis.second + stride - 1) / stride;
                    r = (length - 1 - rangeExceptThis.first) / stride + 1;
                } else {
                    nnet_unimplemented_continue();
                    continue;
                }
                ret = intersectRangeMaps(ret, {{iter, {l, r}}});
            }
        }
    }
    return ret;
}

RangeMap RangeRelaxFunctor::intersectRangeMaps(const RangeMap &a,
                                               const RangeMap &b) {
    RangeMap ret(a);
    for (const auto &[k, v] : b) {
        if (!ret.count(k))
            ret[k] = v;
        else {
            auto const &u = ret[k];
            ret[k] = {max(u.first, v.first), min(u.second, v.second)};
        }
    }
    return ret;
}

} // namespace nnet
@@ -0,0 +1,23 @@
#include "nnet/Visitor/ReplaceNodeMutator.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Expr ReplaceNodeMutator::visit_(const Subscript &c) {
|
||||
if (c.get() == target)
|
||||
return replacement;
|
||||
return Mutator::visit_(c);
|
||||
}
|
||||
Expr ReplaceNodeMutator::visit_(const Tensor &c) {
|
||||
if (c.get() == target)
|
||||
return replacement;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Expr ReplaceNodeMutator::replace(const Expr &root, ExprNode *_target,
|
||||
const Expr &_replace) {
|
||||
target = _target;
|
||||
replacement = _replace;
|
||||
return dispatch(root);
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,50 @@
|
|||
#include "nnet/Visitor/ReplaceVariable.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Expr ReplaceVariable::visit_(const BinaryOp &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
if (auto mutate = match(c); mutate)
|
||||
return mutate;
|
||||
else
|
||||
return Mutator::visit_(c);
|
||||
}
|
||||
|
||||
Expr ReplaceVariable::visit_(const Var &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
if (auto mutate = match(c); mutate)
|
||||
return mutate;
|
||||
else
|
||||
return Mutator::visit_(c);
|
||||
}
|
||||
|
||||
Expr ReplaceVariable::visit_(const RangeOp &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ReplaceVariable::set(VecExpr _pattern, VecExpr _replacement) {
|
||||
patterns = _pattern;
|
||||
replacements = _replacement;
|
||||
for (size_t i = 0; i < patterns.size(); ++i) {
|
||||
auto hash = patterns[i]->hash();
|
||||
assert(patternHash.count(hash) == 0);
|
||||
patternHash[hash] = i;
|
||||
}
|
||||
}
|
||||
|
||||
Expr ReplaceVariable::match(const Expr &c) {
|
||||
auto hash = c->hash();
|
||||
if (auto it = patternHash.find(hash); it != patternHash.end()) {
|
||||
const auto &i = it->second;
|
||||
if (verbose)
|
||||
dbg("Match", *c, *patterns[i], c->hash());
|
||||
return replacements[i];
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,245 @@
|
|||
#include "nnet/Visitor/Serializer.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "nnet/expr.h"
|
||||
#include <fstream>
|
||||
|
||||
namespace nnet {
|
||||
|
||||
int Serializer::id = 0;
|
||||
|
||||
Serializer::Serializer(int _verobse)
|
||||
: Functor(_verobse), jPtr(std::make_unique<json>()), j(*jPtr) {}
|
||||
|
||||
Serializer::~Serializer() = default;
|
||||
|
||||
string Serializer::visit_(const Constant &c) {
|
||||
string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["val"] = c->getValue();
|
||||
return key;
|
||||
}
|
||||
|
||||
string Serializer::visit_(const BinaryOp &c) {
|
||||
string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["opType"] = (int)c->getOpType();
|
||||
j[key]["lhs"] = dispatch(c->getLhs());
|
||||
j[key]["rhs"] = dispatch(c->getRhs());
|
||||
return key;
|
||||
}
|
||||
|
||||
string Serializer::visit_(const RangeOp &c) {
|
||||
string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["paddings"] = c->getPaddings();
|
||||
j[key]["summand"] = dispatch(c->getSummand());
|
||||
for (auto var : c->getLoopVarRanges()) {
|
||||
j[key]["loopVarRanges"][var.first->getName()] = var.second;
|
||||
}
|
||||
for (auto var : c->getSumVarRanges()) {
|
||||
j[key]["sumVarRanges"][var.first->getName()] = var.second;
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
string Serializer::visit_(const Subscript &c) {
|
||||
string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["subExprsNum"] = c->getDims();
|
||||
j[key]["object"] = dispatch(c->getObject());
|
||||
vector<string> indexes;
|
||||
for (auto index : c->getIndex()) {
|
||||
indexes.emplace_back(dispatch(index));
|
||||
}
|
||||
j[key]["indexes"] = indexes;
|
||||
return key;
|
||||
}
|
||||
|
||||
string Serializer::visit_(const Var &c) {
|
||||
string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["name"] = c->getName();
|
||||
return key;
|
||||
}
|
||||
|
||||
string Serializer::visit_(const Tensor &c) {
|
||||
const string key = std::to_string(id++);
|
||||
j[key]["type"] = c->getType();
|
||||
j[key]["name"] = c->getName();
|
||||
j[key]["shape"] = c->getShape();
|
||||
j[key]["paddings"] = c->getPaddings();
|
||||
const auto &routine = c->getSource();
|
||||
j[key]["source"] = dispatchRoutine(routine);
|
||||
return key;
|
||||
}
|
||||
|
||||
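// On-disk layout, sketched: a flat JSON object keyed by visit order with "0"
// as the root, e.g.
//   { "Version": ..., "Msg": "...",
//     "0": {"type": <RangeOp>, "summand": "1", "loopVarRanges": {...}, ...},
//     "1": {"type": <Subscript>, "object": "2", "indexes": [...]}, ... }
// Typical round trip: Serializer().serialize(expr, "expr.json", "msg");
// followed by Serializer().deserialize("expr.json").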
bool Serializer::serialize(const Expr &expr, const string &filePath,
                           const string &msg) {
    // Metadata
    j["Version"] = VERSION;
    j["Msg"] = msg;
    // Expressions and routines
    id = 0;
    dispatch(expr);
    std::ofstream fout(filePath);
    fout << std::setw(4) << j << std::endl;
    return true;
}

string Serializer::dispatchRoutine(const Routine &c) {
    if (!c)
        return "-1";
    const string key = std::to_string(id++);
    j[key]["type"] = c->getType();

    vector<string> inputs;
    for (const auto &tensor : c->getInputs())
        inputs.emplace_back(dispatch(tensor));
    j[key]["inputs"] = inputs;

    if (const auto &expr = c->getExpr())
        j[key]["expr"] = dispatch(expr);
    else
        j[key]["expr"] = "-1";

    switch (c->getType()) {
    case RoutineType::NoneType:
        nnet_unimplemented_halt();
        break;
    case RoutineType::MatmulNodeType: {
        j[key]["args"] = as<MatmulNode>(c)->getArgs();
        break;
    }
    case RoutineType::ConvNodeType:
        j[key]["args"] = as<ConvNode>(c)->getArgs();
        break;
    case RoutineType::G2bmmNodeType:
        j[key]["args"] = as<G2bmmNode>(c)->getArgs();
        break;
    case RoutineType::GbmmNodeType:
        j[key]["args"] = as<GbmmNode>(c)->getArgs();
        break;
    case RoutineType::ElementWiseNodeType: {
        j[key]["outputShape"] = as<ElementWiseNode>(c)->getOutputShape();
        break;
    }
    default:
        nnet_unimplemented_halt();
    }
    return key;
}

Expr Serializer::deserialize(const string &filePath) {
    std::ifstream fin(filePath);
    fin >> j;
    assert(j["Version"] == VERSION);
    return buildExprTree("0");
}

Expr Serializer::buildExprTree(string key) {
    switch (NodeType(j[key]["type"])) {
    case NodeType::ConstantNodeType: {
        return make_ref<ConstantNode>(j[key]["val"]);
    }
    case NodeType::BinaryOpNodeType: {
        auto lhs = buildExprTree(j[key]["lhs"]);
        auto rhs = buildExprTree(j[key]["rhs"]);
        return make_ref<BinaryOpNode>(j[key]["opType"], lhs, rhs);
    }
    case NodeType::RangeOpNodeType: {
        vector<VarRangePair> loopIters, sumIters;
        for (auto &loopIter : j[key]["loopVarRanges"].items()) {
            loopIters.emplace_back(
                pair(make_ref<VarNode>(loopIter.key()),
                     pair(loopIter.value()[0], loopIter.value()[1])));
        }
        for (auto &sumIter : j[key]["sumVarRanges"].items()) {
            sumIters.emplace_back(
                pair(make_ref<VarNode>(sumIter.key()),
                     pair(sumIter.value()[0], sumIter.value()[1])));
        }
        auto summand = buildExprTree(j[key]["summand"]);
        auto paddings = j[key]["paddings"].get<std::vector<int>>();
        auto rangeOp = makeRangeOperator(loopIters, sumIters, summand);
        rangeOp->setPaddings(paddings);
        return rangeOp;
    }
    case NodeType::SubscriptNodeType: {
        auto indexed = buildExprTree(j[key]["object"]);
        VecExpr subExprs;
        for (int i = 0, iEnd = j[key]["subExprsNum"]; i < iEnd; i++) {
            subExprs.emplace_back(buildExprTree(j[key]["indexes"][i]));
        }
        return make_ref<SubscriptNode>(indexed, subExprs);
    }
    case NodeType::VarNodeType: {
        return make_ref<VarNode>(j[key]["name"]);
    }
    case NodeType::TensorNodeType: {
        auto source = buildRoutine(j[key]["source"]);
        return make_ref<TensorNode>(j[key]["name"], j[key]["shape"],
                                    j[key]["paddings"], source);
    }
    default: {
        nnet_unimplemented_halt();
        break;
    }
    }
    return nullptr;
}

Routine Serializer::buildRoutine(string key) {
    if (key == "-1")
        return nullptr;
    Expr expr = nullptr;
    if (j[key]["expr"] != "-1")
        expr = buildExprTree(j[key]["expr"]);
    vector<Tensor> inputs;
    for (const auto &input : j[key]["inputs"])
        inputs.emplace_back(as<TensorNode>(buildExprTree(input)));

    switch (RoutineType(j[key]["type"])) {
    case RoutineType::NoneType:
        nnet_unimplemented_halt();
        break;
    case RoutineType::MatmulNodeType: {
        assert(inputs.size() == 2);
        auto args = j[key]["args"].get<MatmulArgs>();
        auto ctorArgs =
            std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
        return make_ref_from_tuple<MatmulNode>(ctorArgs);
    }
    case RoutineType::ConvNodeType: {
        assert(inputs.size() == 2);
        auto args = j[key]["args"].get<ConvArgs>();
        auto ctorArgs =
            std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
        return make_ref_from_tuple<ConvNode>(ctorArgs);
    }
    case RoutineType::G2bmmNodeType: {
        auto args = j[key]["args"].get<G2bmmArgs>();
        auto ctorArgs =
            std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
        return make_ref_from_tuple<G2bmmNode>(ctorArgs);
    }
    case RoutineType::GbmmNodeType: {
        auto args = j[key]["args"].get<GbmmArgs>();
        auto ctorArgs =
            std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
        return make_ref_from_tuple<GbmmNode>(ctorArgs);
    }
    case RoutineType::ElementWiseNodeType: {
        return make_ref<ElementWiseNode>(expr, inputs, j[key]["outputShape"]);
    }
    default:
        nnet_unimplemented_halt();
    }
    return nullptr;
}

} // namespace nnet
@@ -0,0 +1,160 @@
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
#include "nnet/Visitor/StrideVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
pair<PtrMap<Iterator, int>, int>
|
||||
SimplifyExprVisitor::getStridesConstant(const Expr &expr) {
|
||||
constant = 0;
|
||||
divStrides.clear();
|
||||
modStrides.clear();
|
||||
subexprStride = StrideVisitor().getExprStride(expr);
|
||||
strides.clear();
|
||||
dispatch(expr, 1);
|
||||
return {strides, constant};
|
||||
}
|
||||
|
||||
optional<Range> SimplifyExprVisitor::getExprRange(const Expr &expr,
|
||||
const RangeOp &rangeOp) {
|
||||
constant = 0;
|
||||
divStrides.clear();
|
||||
modStrides.clear();
|
||||
subexprStride = StrideVisitor().getExprStride(expr);
|
||||
strides.clear();
|
||||
dispatch(expr, 1);
|
||||
// Skip if there is divide and modulo
|
||||
if (!divStrides.empty() || !modStrides.empty() || !divExprStrides.empty() ||
|
||||
!modExprStrides.empty())
|
||||
return {};
|
||||
Range ret{constant, constant + 1};
|
||||
for (const auto &[iter, stride] : strides) {
|
||||
const auto &[l, r] = rangeOp->getRange(iter);
|
||||
if (stride > 0) {
|
||||
ret.first += l * stride;
|
||||
ret.second += (r - 1) * stride;
|
||||
} else {
|
||||
ret.first += (r - 1) * stride;
|
||||
ret.second += l * stride;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
PtrMap<Iterator, int> SimplifyExprVisitor::getStrides(const Expr &expr) {
|
||||
return getStridesConstant(expr).first;
|
||||
}
|
||||
|
||||
int SimplifyExprVisitor::getConstant(const Expr &expr) {
|
||||
return getStridesConstant(expr).second;
|
||||
}
|
||||
|
||||
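// Rebuilds a simplified expression from the collected strides. Matching
// div/mod pairs over the same (iter, divisor) merge when their strides are
// compatible: e.g. 8 * (i / 4) + 2 * (i % 4) has divStride 8 and modStride 2
// with 8 / 4 == 2, so the pair collapses to 2 * i.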
Expr SimplifyExprVisitor::simplify(const Expr &expr) {
    getStrides(expr);
    Expr ret = nullptr;
    // Merge divide and modulo items
    for (const auto &[iterDividerPair, divStride] : divStrides) {
        const auto &[iter, mod] = iterDividerPair;
        // mod < 0 is a marker for merging vars with negative strides. In
        // math, a divider < 0 is not well-defined for mod, so it should not
        // exist in our exprs and is only a temporary state which must be
        // simplified now.
        if (mod < 0) { // must be perfectly merged
            const auto &modStride = modStrides[iterDividerPair];
            assert(divStride / abs(mod) == modStride);
            assert(divStride > 0);
            strides.try_emplace(iterDividerPair.first, 0);
            strides[iterDividerPair.first] += abs(divStride / mod);
            modStrides.erase(iterDividerPair);
        } else if (divStride % mod == 0 && modStrides.count(iterDividerPair)) {
            const auto &modStride = modStrides[iterDividerPair];
            if (divStride / mod == modStride) {
                strides.try_emplace(iterDividerPair.first, 0);
                strides[iterDividerPair.first] += divStride / mod;
                modStrides.erase(iterDividerPair);
            } else
                ret = ret + divStride * (iterDividerPair.first /
                                         iterDividerPair.second);
        } else
            ret = ret +
                  divStride * (iterDividerPair.first / iterDividerPair.second);
    }
    // Remaining modulo items
    for (const auto &[iterDividerPair, stride] : modStrides) {
        ret = ret + stride * (iterDividerPair.first % iterDividerPair.second);
    }
    // Normal constant*variable items
    for (const auto &[iter, stride] : strides) {
        if (stride == 0)
            continue;
        Expr subexpr;
        if (stride == 1)
            subexpr = iter;
        else
            subexpr = stride * iter;
        ret = (ret) ? ret + subexpr : subexpr;
    }
    // Not perfectly nested divide and modulo items
    for (const auto &[iterDividerPair, stride] : divExprStrides) {
        ret = ret + stride * (iterDividerPair.first / iterDividerPair.second);
    }
    for (const auto &[iterDividerPair, stride] : modExprStrides) {
        ret = ret + stride * (iterDividerPair.first % iterDividerPair.second);
    }
    ret = ret + constant;
    return ret ? ret : make_ref<ConstantNode>(0);
}

void SimplifyExprVisitor::visit_(const BinaryOp &c, optional<int> stride) {
    if (verbose)
        dbg(c);
    if (c->getOpType() == OpType::Add) {
        dispatch(c->getLhs(), stride);
        dispatch(c->getRhs(), stride);
    } else if (c->getOpType() == OpType::Sub) {
        dispatch(c->getLhs(), stride);
        if (stride)
            *stride = -*stride;
        dispatch(c->getRhs(), stride);
    } else if (c->getOpType() == OpType::Mul) {
        const optional<int> &lStride = subexprStride[c->getLhs().get()];
        const optional<int> &rStride = subexprStride[c->getRhs().get()];
        optional<int> lCurStride =
            (stride && rStride) ? optional(*stride * *rStride) : nullopt;
        optional<int> rCurStride =
            (stride && lStride) ? optional(*stride * *lStride) : nullopt;
        dispatch(c->getLhs(), lCurStride);
        dispatch(c->getRhs(), rCurStride);
    } else if (c->getOpType() == OpType::Mod) {
        const auto &param = c->getModDivParameter();
        if (param.has_value()) {
            modStrides.try_emplace(*param, 0);
            modStrides[*param] += stride.value();
        } else {
            const auto &paramExpr = c->getModDivExpr();
            modExprStrides.try_emplace(paramExpr, 0);
            modExprStrides[paramExpr] += stride.value();
        }
    } else if (c->getOpType() == OpType::Div) {
        const auto &param = c->getModDivParameter();
        if (param.has_value()) {
            divStrides.try_emplace(*param, 0);
            divStrides[*param] += stride.value();
        } else {
            const auto &paramExpr = c->getModDivExpr();
            divExprStrides.try_emplace(paramExpr, 0);
            divExprStrides[paramExpr] += stride.value();
        }
    } else
        nnet_unimplemented_halt();
}
void SimplifyExprVisitor::visit_(const Var &c, optional<int> stride) {
    if (verbose)
        dbg(c);
    strides.try_emplace(c);
    strides[c] += stride.value();
}
void SimplifyExprVisitor::visit_(const Constant &c, optional<int> stride) {
    if (stride.has_value())
        constant += stride.value() * c->getValue();
}

} // namespace nnet
@@ -0,0 +1,30 @@
#include "nnet/Visitor/SimplifyFormulaMutator.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
Expr SimplifyFormulaMutator::visit_(const Subscript &c) {
|
||||
++nSubscripts;
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
bool modified = false;
|
||||
auto ret = make_ref<SubscriptNode>(*c);
|
||||
for (size_t i = 0; i < ret->getDims(); ++i) {
|
||||
const auto &e = ret->getIndex(i);
|
||||
if (const auto &mutated = SimplifyExprVisitor().simplify(e)) {
|
||||
modified = true;
|
||||
ret->setIndex(i, mutated);
|
||||
}
|
||||
}
|
||||
return (modified) ? ret : nullptr;
|
||||
}
|
||||
|
||||
Expr SimplifyFormulaMutator::simplify(const Expr &expr) {
|
||||
nSubscripts = 0;
|
||||
const auto &ret = dispatch(expr);
|
||||
nnet_assert(nSubscripts > 0,
|
||||
"Subscript NOT found. Use SimplifyFormulaMutator?");
|
||||
return (ret) ? ret : expr;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@ -0,0 +1,59 @@
|
|||
#include "nnet/Visitor/StrideVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
optional<int> StrideVisitor::visit_(const Subscript &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
auto object = as<TensorNode>(c->getObject());
|
||||
assert(object);
|
||||
for (int i = (int)c->getDims() - 1; i >= 0; --i)
|
||||
this->dispatch(c->getIndex(i));
|
||||
return {};
|
||||
}
|
||||
|
||||
optional<int> StrideVisitor::visit_(const Var &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
// assert(subexprStride.count(&c) == 0);
|
||||
subexprStride[c.get()] = {};
|
||||
return {};
|
||||
}
|
||||
|
||||
optional<int> StrideVisitor::visit_(const Constant &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
optional ret{c->getValue()};
|
||||
// assert(subexprStride.count(&c) == 0);
|
||||
subexprStride[c.get()] = ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
optional<int> StrideVisitor::visit_(const BinaryOp &c) {
|
||||
if (verbose)
|
||||
dbg(*c);
|
||||
optional<int> strideL = this->dispatch(c->getLhs());
|
||||
optional<int> strideR = this->dispatch(c->getRhs());
|
||||
if (!strideL || !strideR)
|
||||
return {};
|
||||
optional<int> ret;
|
||||
switch (c->getOpType()) {
|
||||
case OpType::Add:
|
||||
ret = optional(*strideL + *strideR);
|
||||
break;
|
||||
case OpType::Sub:
|
||||
ret = optional(*strideL - *strideR);
|
||||
break;
|
||||
case OpType::Mul:
|
||||
ret = optional(*strideL * *strideR);
|
||||
break;
|
||||
default:
|
||||
nnet_unimplemented_halt();
|
||||
break;
|
||||
}
|
||||
// assert(subexprStride.count(&c) == 0);
|
||||
subexprStride[c.get()] = ret;
|
||||
return ret;
|
||||
}
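
// StrideVisitor caches, per subexpression, the constant stride it evaluates
// to (nullopt for anything containing a Var). SimplifyExprVisitor's Mul case
// reads this cache: in (i * 4), the cached stride 4 of the right operand is
// multiplied into the stride with which i is dispatched.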

} // namespace nnet

@@ -0,0 +1,156 @@
#include "nnet/Visitor/HashVisitor.h"
|
||||
#include "nnet/Visitor/FullPrinterVisitor.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
namespace nnet {
|
||||
|
||||
constexpr int varPrefix = 11027;
|
||||
constexpr int binPrefix = 11047;
|
||||
constexpr int ssPrefix = 11057;
|
||||
constexpr int addPrefix = 11059;
|
||||
constexpr int mulPrefix = 11069;
|
||||
constexpr int vecPrefix = 11071;
|
||||
constexpr int tensorPrefix = 11083;
|
||||
constexpr int valSuffix = 6214587;
|
||||
|
||||
static inline HashType hash(const HashType a, const HashType b) {
|
||||
return (a * 10007 + b + 12345) % 1000000007;
|
||||
}
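
// A polynomial rolling hash modulo the prime 1000000007; for example,
// hash(1, 2) == (1 * 10007 + 2 + 12345) % 1000000007 == 22354.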

static inline HashType hash(const std::string &s) {
    HashType ret = 0;
    for (auto c : s)
        ret = hash(ret, c);
    return ret;
}

static inline HashType hash(const OpType c) { return HashType(c); }

HashType HashVisitor::getHash(const Expr &c) { return dispatch(c); }

HashType HashVisitor::visit_(const Constant &c) {
    auto val = c->getValue();
    return genhash(val, valSuffix);
}

HashType HashVisitor::visit_(const BinaryOp &c) {
    HashType hasha = dispatch(c->getLhs());
    HashType hashb = dispatch(c->getRhs());

    if (c->isSwapable()) {
        if (hasha > hashb) {
            std::swap(hasha, hashb);
        }
    }
    return hash(binPrefix, hash(hash(c->getOpType()), hash(hasha, hashb)));
}
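
// Sorting the operand hashes of commutative (isSwapable) operators makes,
// e.g., (a + b) and (b + a) hash identically, so hash-based pruning treats
// them as the same formula.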

HashType hashLoopVar(const int id, const Range &range) {
    return hash(varPrefix, hash(id, hash(range.first, range.second)));
}

HashType hashSumVar(const Range &range) {
    return hash(varPrefix, hash(range.first, range.second));
}

HashType HashVisitor::visit_(const RangeOp &c) {
    // Identify loop variables
    for (const auto &[var, range] : c->getLoopVarRanges()) {
        nnet_assert(varHash.find(var) == varHash.end(),
                    "In HashVisitor::RangeOp invalid loop var.");
        varHash[var] = hashLoopVar(nLoopVars++, range);
    }

    // Identify sum variables according to range
    for (const auto &[var, range] : c->getSumVarRanges()) {
        nnet_assert(varHash.find(var) == varHash.end(),
                    "In HashVisitor::RangeOp invalid sum var.");
        varHash[var] = hashSumVar(range);
    }

    auto expr = c->getSummand();
    return dispatch(expr);
}

HashType HashVisitor::visit_(const Subscript &c) {
    HashType curHash = ssPrefix;
    auto obj = c->getObject();
    if (obj->getType() == NodeType::RangeOpNodeType) {
        curHash = hash(curHash, HashVisitor().getHash(obj));
    } else if (obj->getType() == NodeType::TensorNodeType) {
        // TODO: hash should be based on arguments
        curHash = hash(curHash, dispatch(obj));
    } else {
        nnet_unimplemented_halt();
    }

    for (const auto &expr : c->getIndex()) {
        if (expr->getType() == NodeType::BinaryOpNodeType) {
            HashType tmp = addPrefix;
            std::vector<std::pair<HashType, HashType>> coefficients;
            auto seVisitor = SimplifyExprVisitor();
            auto [c, x] = seVisitor.getStridesConstant(expr);
            for (const auto &[key, value] : c) {
                coefficients.emplace_back(varHash[key], value);
            }
            for (const auto &[iter, value] : seVisitor.getDivStrides()) {
                nnet_assert(iter.second != 1, "invalid div expr");
                coefficients.emplace_back(
                    hash(binPrefix, hash(varHash[iter.first], iter.second)),
                    value);
            }
            sort(coefficients.begin(), coefficients.end());
            tmp = hash(tmp, x);
            for (const auto &[key, value] : coefficients) {
                tmp = hash(tmp, hash(mulPrefix, hash(key, value)));
            }
            curHash = hash(curHash, tmp);
            continue;
        }
        if (expr->getType() == NodeType::ConstantNodeType) {
            curHash = hash(curHash, dispatch(expr));
            continue;
        }
        if (expr->getType() == NodeType::VarNodeType) {
            curHash = hash(curHash, dispatch(expr));
            continue;
        }
        nnet_unimplemented_halt();
    }
    return curHash;
}

HashType hashPadding(const std::vector<int> &pad) {
    HashType cur = hash(vecPrefix, pad.size());
    for (const auto &e : pad) {
        cur = hash(cur, e);
    }
    return cur;
}

HashType HashVisitor::visit_(const Tensor &c) {
    // TODO: remove this
    // TODO: check if hash name includes padding.

    if (c->getSource() == nullptr) {
        return hash(tensorPrefix, genhash(c->getName()));
    }
    // dbg(c, c->getSource()->getExpr(), FullPrinterVisitor().print(c),
    //     FullPrinterVisitor().print(c->getSource()->getExpr()));
    // std::cout << "Tensor: " << int(c->getSource()->getExpr()->getType())
    //           << std::endl;
    // std::cout << "Tensor: " << c->getSource()->getExpr()->toReadable()
    //           << std::endl;
    return hash(tensorPrefix,
                HashVisitor().dispatch(c->getSource()->getExpr()));
}

HashType HashVisitor::visit_(const Var &c) {
    if (varHash.find(c) == varHash.end()) {
        nnet_unimplemented_halt();
        return 0;
    }
    return varHash[c];
}

} // namespace nnet

@@ -0,0 +1,529 @@
#include "nnet/derivator.h"
|
||||
#include "nnet/Pass/MatchComputationKernel.h"
|
||||
#include "nnet/Pass/MatchMemBoundKernel.h"
|
||||
#include "nnet/Pass/Rule1VariableSplit.h"
|
||||
#include "nnet/Pass/Rule2VariableMerging.h"
|
||||
#include "nnet/Pass/Rule3StageSplit.h"
|
||||
#include "nnet/Pass/Rule4StageMerging.h"
|
||||
#include "nnet/Pass/Rule5RangeRelaxation.h"
|
||||
#include "nnet/Pass/Rule6KenerlMatching.h"
|
||||
#include "nnet/Pass/Rule7DLT.h"
|
||||
#include "nnet/Pass/Rule8GuidedDLT.h"
|
||||
#include "nnet/Pass/Rule90TwoStageElementWise.h"
|
||||
#include "nnet/Pass/Rule91MergeStagesWithSum.h"
|
||||
#include "nnet/Pass/Rule9RangeMagnify.h"
|
||||
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||
#include "nnet/Visitor/CloneMutator.h"
|
||||
#include "nnet/Visitor/CompareMultiFormulasVisitor.h"
|
||||
#include "nnet/Visitor/CountRoutineVisitor.h"
|
||||
#include "nnet/Visitor/FullPrinterVisitor.h"
|
||||
#include "nnet/Visitor/HashVisitor.h"
|
||||
#include "nnet/Visitor/MergeMemboundMutator.h"
|
||||
#include "nnet/Visitor/Serializer.h"
|
||||
#include "nnet/test.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
class SaveStateGuard {
|
||||
Derivator &derivator;
|
||||
|
||||
public:
|
||||
SaveStateGuard(Derivator &derivator, const Expr &origin,
|
||||
const string &ruleName, const string &ruleMsg = "")
|
||||
: derivator(derivator) {
|
||||
derivator.pushIntermediateState(origin);
|
||||
derivator.pushRuleState(ruleName);
|
||||
derivator.pushRuleMsg(ruleMsg);
|
||||
}
|
||||
~SaveStateGuard() {
|
||||
derivator.popIntermediateState();
|
||||
derivator.popRuleState();
|
||||
derivator.popRuleMsg();
|
||||
}
|
||||
};
|
||||
|
||||
#define SetUpStateGuard() \
|
||||
SaveStateGuard __guard(*this, origin.root, __FUNCTION__)
|
||||
|
||||
void Derivator::dfs(Formula &origin, int depth) {
|
||||
guidedSearch(origin, depth);
|
||||
|
||||
if (depth >= maxDepth) {
|
||||
return;
|
||||
}
|
||||
Expr *curExpr = &origin.root;
|
||||
nnet_assert((*curExpr)->getType() == NodeType::RangeOpNodeType, __LINE__);
|
||||
while ((*curExpr)->getType() == NodeType::RangeOpNodeType) {
|
||||
auto curRangeOp = as<RangeOpNode>(*curExpr);
|
||||
checkOOB(curRangeOp);
|
||||
auto summand = curRangeOp->getSummand();
|
||||
if (summand->getType() == NodeType::SubscriptNodeType) {
|
||||
auto subscriptOp = as<SubscriptNode>(summand);
|
||||
if (rule4StageMerging(origin, depth, *curExpr)) {
|
||||
return;
|
||||
}
|
||||
curExpr = subscriptOp->getObjectPtr();
|
||||
nnet_assert(*curExpr != nullptr, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (summand->getType() == NodeType::BinaryOpNodeType) {
|
||||
if (cntAppliedRules[1] < 3)
|
||||
rule1VariableSplit(origin, depth, *curExpr); // +1/0
|
||||
rule2VariableMerging(origin, depth, *curExpr); // +1
|
||||
if (cntAppliedRules[3] < 1)
|
||||
rule3StageSplit(origin, depth, *curExpr); // +1
|
||||
rule5RangeRelaxation(origin, depth, *curExpr); // 0
|
||||
rule7DLT(origin, depth, *curExpr);
|
||||
rule9RangeMagnify(origin, depth, *curExpr);
|
||||
return;
|
||||
}
|
||||
nnet_unimplemented_halt();
|
||||
}
|
||||
// RangeOp curRangeOp;
|
||||
// for (Expr *curExpr = &origin.root;
|
||||
// curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
|
||||
// checkOOB(curRangeOp);
|
||||
// auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
|
||||
// // isSimplyNested: a directly nested stage
|
||||
// bool isSimplyNested = (subscript &&
|
||||
// subscript->isRangeOpSubscripted()); if (rule4StageMerging(origin,
|
||||
// depth, *curExpr))
|
||||
// return;
|
||||
// // For the next nested stage
|
||||
// curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
|
||||
// }
|
||||
|
||||
// int stage = 0;
|
||||
// for (Expr *curExpr = &origin.root;
|
||||
// curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
|
||||
// stage++;
|
||||
// // isSimplyNested: a directly nested stage
|
||||
// auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
|
||||
// bool isSimplyNested = (subscript &&
|
||||
// subscript->isRangeOpSubscripted());
|
||||
|
||||
// // TODO recover it
|
||||
// // permuteRangeOps(origin, depth, *curExpr);
|
||||
// // extractSubexpression(origin, depth, *curExpr);
|
||||
|
||||
// rule4StageMerging(origin, depth, *curExpr);
|
||||
|
||||
// if (!isSimplyNested) {
|
||||
// std::cout << "num stage: " << depth << " " << stage << std::endl;
|
||||
// if (depth < 5) {
|
||||
// rule1VariableSplit(origin, depth, *curExpr); // +1/0
|
||||
// rule3StageSplit(origin, depth, *curExpr); // +1
|
||||
// rule2VariableMerging(origin, depth, *curExpr); // +1
|
||||
// rule5RangeRelaxation(origin, depth, *curExpr); // 0
|
||||
// rule9RangeMagnify(origin, depth, *curExpr);
|
||||
// }
|
||||
// if (depth >= 5) {
|
||||
// rule1VariableSplit(origin, depth, *curExpr); // +1/0
|
||||
// rule3StageSplit(origin, depth, *curExpr); // +1
|
||||
// rule2VariableMerging(origin, depth, *curExpr); // +1
|
||||
// rule5RangeRelaxation(origin, depth, *curExpr); // 0
|
||||
// rule6KenerlMatching(origin, depth, *curExpr); // -1
|
||||
// rule7DLT(origin, depth, *curExpr); // +1
|
||||
// rule8GuidedDLT(origin, depth, *curExpr); //
|
||||
// rule9RangeMagnify(origin, depth, *curExpr);
|
||||
// }
|
||||
// }
|
||||
// // For the next nested stage
|
||||
// curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
|
||||
// }
|
||||
}
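
// A sketch of the loop above: dfs() walks the chain of nested RangeOp
// stages. Directly nested stages are merged eagerly (rule 4); a stage whose
// summand is a BinaryOp is the site where the expression-level rules
// (1, 2, 3, 5, 7, 9) are tried. Each rule application goes through
// nextStep(), which recurses back into dfs() or ruleBasedDerivate().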

Derivator::Derivator(int maxDepth, bool enableHashPruning, LogMode logMode,
                     PassMode passMode)
    : maxDepth(maxDepth), logMode(logMode), passMode(passMode),
      enableHashPruning(enableHashPruning), cntAppliedRules(12) {}

int Derivator::getNumIntermediateStates() { return cntStates; }

void Derivator::guidedSearch(Formula &origin, int depth) {
    if (origin.root->getType() == NodeType::TensorNodeType) {
        auto tensor = as<TensorNode>(origin.root);
        appendCanddiate(tensor, depth);
        return;
    }
    Expr *expr = &origin.root;
    nnet_assert((*expr)->getType() == NodeType::RangeOpNodeType, __LINE__);
    while ((*expr)->getType() == NodeType::RangeOpNodeType) {
        auto rangeOp = as<RangeOpNode>(*expr);
        checkOOB(rangeOp);
        auto summand = rangeOp->getSummand();
        if (summand->getType() == NodeType::SubscriptNodeType) {
            auto subscriptOp = as<SubscriptNode>(summand);
            if (rule4StageMerging(origin, depth, *expr)) {
                return;
            }
            expr = subscriptOp->getObjectPtr();
            nnet_assert(*expr != nullptr, __LINE__);
            continue;
        }
        if (summand->getType() == NodeType::BinaryOpNodeType) {
            break;
        }
        nnet_unimplemented_halt();
    }

    if (searchState == 0) {
        searchState = 1;
        rule8GuidedDLT(origin, depth, *expr);
        searchState = 0;
        return;
    }
    if (searchState == 1) {
        rule8GuidedDLT(origin, depth, *expr);
        return;
    }
    if (searchState == 2) {
        matchComputationKernel(origin, depth, *expr);
        return;
    }
    if (searchState == 3) {
        // Pack the remaining computation as a MemBoundOp
        matchMemBoundKernel(origin, depth, origin.root);
        return;
    }
    nnet_unimplemented_halt();
}

void Derivator::ruleBasedDerivate(Formula &origin, int depth) {
    // string StartDfs = "ruleBasedDerivate dep=" + std::to_string(depth) +
    //                   ", targetOp=" +
    //                   std::to_string(routineTypeToId(targetOp));
    // dbg(StartDfs, origin);
    auto tensor = as<TensorNode>(origin.root);
    if (tensor) {
        appendCanddiate(tensor, depth);
        return;
    }
    if (depth >= (int)rulesOverall.size())
        return;
    RangeOp curRangeOp;
    for (Expr *curExpr = &origin.root;
         curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
        checkOOB(curRangeOp);
        auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
        // isSimplyNested: a directly nested stage
        bool isSimplyNested = (subscript && subscript->isRangeOpSubscripted());
        if (rule4StageMerging(origin, depth, *curExpr))
            return;
        // For the next nested stage
        curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
    }
    int stageDepth = 0;
    for (Expr *curExpr = &origin.root;
         curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
        // isSimplyNested: a directly nested stage
        auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
        bool isSimplyNested = (subscript && subscript->isRangeOpSubscripted());
        stageDepth++;

        for (int rule : rulesOverall[depth]) {
            if (rule == 1)
                rule1VariableSplit(origin, depth, *curExpr);
            else if (!isSimplyNested) {
                if (rule == 2)
                    rule2VariableMerging(origin, depth, *curExpr);
                else if (rule == 3)
                    rule3StageSplit(origin, depth, *curExpr);
                else if (rule == 5)
                    rule5RangeRelaxation(origin, depth, *curExpr);
                else if (rule == 6)
                    rule6KenerlMatching(origin, depth, *curExpr);
                else if (rule == 7)
                    rule7DLT(origin, depth, *curExpr);
                else if (rule == 8)
                    rule8GuidedDLT(origin, depth, *curExpr);
                else if (rule == 9)
                    rule9RangeMagnify(origin, depth, *curExpr);
            }
        }
        // For the next nested stage
        curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
    }
    for (int rule : rulesOverall[depth])
        if (rule == 90 && stageDepth == 2) // HACK: for (T)Conv2gemm
            rule90TwoStageElementWise(origin, depth, origin.root);
        else if (rule == 91 && stageDepth >= 2) // HACK: for TConv2gemm
            rule91MergeStagesWithSum(origin, depth, origin.root);
}

void Derivator::nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur) {
    // Count the number of searched states
    ++cntStates;
    rCur.swap(newCur);

    HashType formulaHash = HashVisitor().getHash(origin.root);
    if (enableHashPruning) {
        if (searchState != 2) {
            if (visited.find(formulaHash) != visited.end()) {
                rCur.swap(newCur);
                return;
            }
            visited.emplace(formulaHash);
        }
    }

    if (searchState > 0) {
        guidedSearch(origin, depth);
    } else {
        searchedMaxDepth = max(searchedMaxDepth, depth + 1);

        if (searchStrategy == Strategy::DFS ||
            (searchStrategy == Strategy::RuleAndDFS &&
             depth + 1 >= (ssize_t)rulesOverall.size()))
            dfs(origin, depth + 1);
        else
            ruleBasedDerivate(origin, depth + 1);
    }
    rCur.swap(newCur);
}
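
// nextStep() applies a rewrite in place (rCur.swap(newCur)), explores the
// rewritten formula, and swaps back before returning, so the caller's
// expression tree is restored after the subtree has been searched. Hash
// pruning skips any formula whose HashVisitor hash was already visited,
// except in search state 2 (computation-kernel matching).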

void Derivator::ruleBasedDFS(Formula &origin, int depth, vector<int> _rules,
                             map<int, vector<Iterator>> _substituteRules,
                             bool searchAfterRules) {
    SaveStateGuard guard(*this, origin.root, string("Init: ") + __FUNCTION__);
    searchStrategy = (searchAfterRules) ? Strategy::RuleAndDFS : Strategy::Rule;
    rulesOverall.clear();
    for (auto i : _rules)
        rulesOverall.push_back({i});
    substituteRules = _substituteRules;
    ruleBasedDerivate(origin, depth);
}

void Derivator::search(Formula &origin, int depth) {
    SaveStateGuard guard(*this, origin.root, string("Init: ") + __FUNCTION__);
    searchStrategy = Strategy::DFS;
    dfs(origin, depth);
}

void Derivator::print() {
    std::cout << "[RESULT] Derivator::results: " << candidates.size()
              << std::endl;
    std::cout << "==== DFS candidates (" << candidates.size()
              << ") ====" << std::endl;
    for (const auto &f : candidates) {
        std::cout << f.toReadable() << std::endl;
        // dbg(f.bfsDepth, f.toReadable());
    }
    std::cout << "==== DFS log end ====" << std::endl;
}

string Formula::toReadable() const { return FullPrinterVisitor().print(root); }

void Derivator::rule1VariableSplit(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[1];
    Rule1VariableSplit(*this).run(origin, depth, rCur);
    --cntAppliedRules[1];
}

void Derivator::rule2VariableMerging(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[2];
    Rule2VariableMerging(*this).run(origin, depth, rCur);
    --cntAppliedRules[2];
}

void Derivator::rule3StageSplit(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[3];
    Rule3StageSplit(*this).run(origin, depth, rCur);
    --cntAppliedRules[3];
}

bool Derivator::rule4StageMerging(Formula &origin, int depth, Expr &rCur,
                                  bool mergeStageWithCalc) {
    ++cntAppliedRules[4];
    Rule4StageMerging pass(*this);
    pass.setMergeStageWithCalc(mergeStageWithCalc);
    pass.run(origin, depth, rCur);
    --cntAppliedRules[4];
    return pass.isSuccessful();
}

void Derivator::rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[5];
    Rule5RangeRelaxation(*this).run(origin, depth, rCur);
    --cntAppliedRules[5];
}

void Derivator::rule6KenerlMatching(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[6];
    Rule6KenerlMatching(*this).run(origin, depth, rCur);
    --cntAppliedRules[6];
}

void Derivator::rule7DLT(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[7];
    Rule7DLT(*this).run(origin, depth, rCur);
    --cntAppliedRules[7];
}

void Derivator::rule8GuidedDLT(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[8];
    Rule8GuidedDLT(*this).run(origin, depth, rCur);
    --cntAppliedRules[8];
}

void Derivator::rule9RangeMagnify(Formula &origin, int depth, Expr &rCur) {
    ++cntAppliedRules[9];
    Rule9RangeMagnify(*this).run(origin, depth, rCur);
    --cntAppliedRules[9];
}

void Derivator::rule90TwoStageElementWise(Formula &origin, int depth,
                                          Expr &rCur) {
    Rule90TwoStageElementWise(*this).run(origin, depth, rCur);
}

void Derivator::rule91MergeStagesWithSum(Formula &origin, int depth,
                                         Expr &rCur) {
    Rule91MergeStagesWithSum(*this).run(origin, depth, rCur);
}

void Derivator::matchComputationKernel(Formula &origin, int depth, Expr &rCur) {
    MatchComputationKernel(*this).run(origin, depth, rCur);
}

void Derivator::matchMemBoundKernel(Formula &origin, int depth, Expr &rCur) {
    MatchMemBoundKernel(*this).run(origin, depth, rCur);
}

bool Derivator::stageCombination(MultiFormulas &origin, int depth) {
    return (CompareMultiFormulasVisitor().compare(origin.roots));
}

Expr Derivator::mergeMemboundStages(VecExpr stages) {
    auto nested = MergeMemboundMutator(stages).merge();
    return nested;
}

void Derivator::appendCanddiate(const Tensor &tensor, int depth) {
    // if (!CountRoutineVisitor().match(tensor, 1, 0, 3))
    //     return;

    candidates.emplace_back(tensor, depth);
    // dbg("!!!!!!!!!!!!!!!Success!!!!!!!!!!!!!!!");
    if (enableEquivalenceCheck)
        checkDerivationEquivalence();
    // printIntermediateStates();
}

bool Derivator::checkOOB(const RangeOp &rangeOp, bool halt) {
    // Skip the check in NoLog mode
    if (logMode == LogMode::NoLog)
        return false;
    bool hasOOB = CheckOOBVisitor().checkRangeOp(rangeOp);
    if (hasOOB) {
        printIntermediateStates();
        dbg(FullPrinterVisitor().print(rangeOp));
        if (halt)
            nnet_assert(0, "Out-of-bound index!");
    }
    return hasOOB;
}

string Derivator::newTensorName() {
    return "T" + std::to_string(++nTensorNames);
}

Var Derivator::getNewVar() {
    return make_ref<VarNode>("i" + std::to_string(++nIteratorNames));
}

void Derivator::pushIntermediateState(const Expr &expr) {
    intermediateStates.emplace_back(CloneMutator().clone(expr));
}

void Derivator::pushRuleState(const string &state) {
    ruleStates.emplace_back(state);
}

void Derivator::pushRuleMsg(const string &state) {
    ruleMsgs.emplace_back(state);
}

void Derivator::popIntermediateState() { intermediateStates.pop_back(); }

void Derivator::popRuleState() { ruleStates.pop_back(); }

void Derivator::popRuleMsg() { ruleMsgs.pop_back(); }

RoutineType Derivator::getTargetOp() { return targetOp; }

void Derivator::setTargetOp(RoutineType _targetOp) { targetOp = _targetOp; }

int Derivator::getSearchState() { return searchState; }

void Derivator::setSearchState(int _searchState) { searchState = _searchState; }

void Derivator::printStatistics() {
    printf("==== Derivator statistics ====\n");
    printf("Max Depth = %d\n", maxDepth);
    printf("searchStrategy = ");
    if (searchStrategy == Strategy::DFS)
        printf("DFS\n");
    else if (searchStrategy == Strategy::Rule)
        printf("Rule\n");
    else if (searchStrategy == Strategy::RuleAndDFS)
        printf("RuleAndDFS\n");
    printf("enableHashPruning = %s\n", enableHashPruning ? "true" : "false");
    printf("Reached Max Depth during search = %d\n", searchedMaxDepth);
    printf("#Candidates = %lu\n", candidates.size());
    printf("#Intermediate states = %d\n", cntStates);
    printf("#Hashed intermediate states = %lu\n", visited.size());
    printf("#Iterators = %d\n", nIteratorNames);
    printf("#Tensors = %d\n", nTensorNames);
}

void Derivator::setDumpFirstSuccess(const string &_logFnPrefix) {
    setEquivalenceCheck();
    logMode = LogMode::DumpFristCandiate;
    logFnPrefix = _logFnPrefix;
}

void Derivator::printIntermediateStates() {
    // Skip in NoLog mode
    if (logMode == LogMode::NoLog)
        return;
    assert(intermediateStates.size() == ruleStates.size());
    assert(intermediateStates.size() == ruleMsgs.size());
    for (size_t i = 0; i < intermediateStates.size(); ++i) {
        string msg = "=== Depth " + std::to_string(i) + " " + ruleStates[i] +
                     ": " + ruleMsgs[i];
        std::cout << msg << endl;
        std::cout << FullPrinterVisitor().print(intermediateStates[i]) << endl;
        if (logMode == LogMode::DumpFristCandiate) {
            Serializer serializer;
            serializer.serialize(intermediateStates[i],
                                 logFnPrefix + to_string(i) + ".expr", msg);
        }
    }
    for (size_t i = 0; i < intermediateStates.size(); ++i) {
        if (auto cur = as<RangeOpNode>(intermediateStates[i]))
            if (CheckOOBVisitor().checkRangeOp(cur)) {
                printf("OOB detected depth=%lu\n", i);
            }
    }
    if (logMode == LogMode::DumpFristCandiate) {
        puts("Serialization finished.");
        exit(0);
    }
}

void Derivator::checkDerivationEquivalence() {
    if (!checkExprsEquvivalence(intermediateStates)) {
        nnet_assert(0, "Inequivalent derivation");
        exit(1);
    }
}

void Derivator::setEquivalenceCheck() { enableEquivalenceCheck = true; }

Derivator::PassMode Derivator::getPassMode() { return passMode; }

Derivator::LogMode Derivator::getLogMode() { return logMode; }

} // namespace nnet

@@ -0,0 +1,131 @@
#include "nnet/dlt.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
#include "nnet/visitor.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace nnet {
|
||||
|
||||
optional<Expr> DLT::apply(const RangeOp &rangeOp, const Subscript &subscript,
|
||||
string newTensorName) {
|
||||
const auto &tensor = as<TensorNode>(subscript->getObject());
|
||||
assert(tensor != nullptr);
|
||||
vector<int> shape0(tensor->getShape()), shape1;
|
||||
const auto &paddings = tensor->getPaddings();
|
||||
VecExpr index0 = subscript->getIndex(), index1;
|
||||
// compute new shapes and index
|
||||
for (const auto &opPtr : ops) {
|
||||
if (auto op = as<DLTSplit>(opPtr)) {
|
||||
assert(op->dim < (int)shape0.size());
|
||||
for (int i = 0; i < (int)shape0.size(); ++i)
|
||||
if (i != op->dim) {
|
||||
shape1.emplace_back(shape0[i]);
|
||||
index1.emplace_back(index0[i]);
|
||||
} else {
|
||||
assert(shape0[i] % op->factor == 0);
|
||||
shape1.emplace_back(shape0[i] / op->factor);
|
||||
shape1.emplace_back(op->factor);
|
||||
if (const auto &opt =
|
||||
splitIndex(index0[i], op->factor, rangeOp);
|
||||
opt.has_value()) {
|
||||
index1.emplace_back(opt->first);
|
||||
index1.emplace_back(opt->second);
|
||||
} else
|
||||
return {};
|
||||
}
|
||||
} else if (auto op = as<DLTMerge>(opPtr)) {
|
||||
assert(op->dim0 < (int)shape0.size());
|
||||
assert(op->dim1 < (int)shape0.size());
|
||||
for (int i = 0; i < (int)shape0.size(); ++i)
|
||||
if (i == op->dim0) {
|
||||
shape1.emplace_back(shape0[op->dim0] * shape0[op->dim1]);
|
||||
index1.emplace_back(index0[op->dim0] * shape0[op->dim1] +
|
||||
index0[op->dim1]);
|
||||
} else if (i != op->dim1) {
|
||||
shape1.emplace_back(shape0[i]);
|
||||
index1.emplace_back(index0[i]);
|
||||
}
|
||||
} else if (auto op = as<DLTReorder>(opPtr)) {
|
||||
if (op->dims.size() != shape0.size()) {
|
||||
// TODO: input Reorder should have the same order with tensor
|
||||
nnet_unimplemented_continue();
|
||||
return {};
|
||||
}
|
||||
assert(op->dims.size() == shape0.size());
|
||||
for (size_t i = 0; i < shape0.size(); ++i) {
|
||||
shape1.emplace_back(shape0[op->dims[i]]);
|
||||
index1.emplace_back(index0[op->dims[i]]);
|
||||
}
|
||||
}
|
||||
for (const auto &index : index1) {
|
||||
// Maybe there are bugs...
|
||||
// assert(index != nullptr);
|
||||
if (index == nullptr) {
|
||||
std::cout << "Warning empty" << std::endl;
|
||||
return {};
|
||||
}
|
||||
}
|
||||
shape0.swap(shape1);
|
||||
shape1.clear();
|
||||
index0.swap(index1);
|
||||
index1.clear();
|
||||
}
|
||||
for (auto &index : index0) {
|
||||
// Maybe there are bugs...
|
||||
assert(index != nullptr);
|
||||
if (index == nullptr)
|
||||
return {};
|
||||
index = SimplifyExprVisitor().simplify(index);
|
||||
}
|
||||
// HACK DLT with paddings: transfer original paddings to the new one
|
||||
vector<int> dltedPaddings =
|
||||
(paddings.size() == shape0.size()) ? paddings : vector<int>{};
|
||||
// TODO [necessary] build DLT source expr. Is OP-based DLT is good too?
|
||||
// HACK [important] fix this fake tensor.
|
||||
auto elementRoutine = make_ref<ElementWiseNode>(
|
||||
// FIXME: implement transpose
|
||||
// makeTensor(newTensorName + "_DLT", {}), vector<Tensor>{tensor},
|
||||
// shape0);
|
||||
makeTensor("__DLT", {}), vector<Tensor>{tensor}, shape0);
|
||||
auto dltedTensor =
|
||||
makeTensor(newTensorName, shape0, dltedPaddings, elementRoutine);
|
||||
auto dltedSubscript = makeSubscript(dltedTensor, index0);
|
||||
return optional<Expr>(std::in_place, dltedSubscript);
|
||||
}
|
||||
|
||||
optional<pair<Expr, Expr>> DLT::splitIndex(Expr expr, int factor,
|
||||
RangeOp rangeOp) {
|
||||
auto strides = SimplifyExprVisitor().getStrides(expr);
|
||||
Expr quotient, remainder;
|
||||
for (const auto &[iter, stride] : strides) {
|
||||
const auto &[var, range] = rangeOp->getVarRange(iter);
|
||||
// Add new expr, dealing with the initial empty expr
|
||||
auto addExpr = [](Expr &orig, const Expr &newExpr) {
|
||||
if (!orig)
|
||||
orig = newExpr;
|
||||
else
|
||||
orig = orig + newExpr;
|
||||
};
|
||||
if (abs(stride) >= factor) {
|
||||
if (stride % factor)
|
||||
return {};
|
||||
addExpr(quotient, (stride / factor) * var);
|
||||
} else {
|
||||
if (stride * (range.second - range.first) > factor)
|
||||
return {};
|
||||
addExpr(remainder, stride * var);
|
||||
}
|
||||
}
|
||||
return optional<pair<Expr, Expr>>(std::in_place, quotient, remainder);
|
||||
}
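
// Worked example: splitting the index (4 * i1 + i0), where i0 ranges over
// [0, 4), by factor 4 yields quotient i1 (stride 4 / factor 4 == 1) and
// remainder i0, i.e. the inverse of composing i = 4 * i1 + i0.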

void DLT::split(int dim, int factor) {
    ops.emplace_back(make_ref<DLTSplit>(dim, factor));
}
void DLT::merge(int dim0, int dim1) {
    ops.emplace_back(make_ref<DLTMerge>(dim0, dim1));
}
void DLT::reorder(vector<int> dims) {
    ops.emplace_back(make_ref<DLTReorder>(dims));
}

} // namespace nnet

@@ -0,0 +1,452 @@
#include "nnet/expr.h"
|
||||
#include "nnet/Visitor/GetTensorsVisitor.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
string serializeVec(vector<Expr> v) {
|
||||
if (v.empty())
|
||||
return "[]";
|
||||
return "[" +
|
||||
std::accumulate(v.begin() + 1, v.end(), v[0]->toReadable(),
|
||||
[](const string &a, Expr b) {
|
||||
return a + ',' + b->toReadable();
|
||||
}) +
|
||||
"]";
|
||||
}
|
||||
|
||||
string serializeVec(vector<Var> v) {
|
||||
VecExpr vv;
|
||||
for (const auto &a : v)
|
||||
vv.emplace_back(a);
|
||||
return serializeVec(vv);
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ios, const ExprNode &expr) {
|
||||
ios << expr.toReadable();
|
||||
return ios;
|
||||
}
|
||||
|
||||
TensorNode::TensorNode(string _name, vector<int> _shape, vector<int> _paddings,
|
||||
Routine _source)
|
||||
: name(_name), shape(_shape), paddings(_paddings), source(_source) {
|
||||
if (source && source->getExpr()) {
|
||||
if (auto range = as<RangeOpNode>(source->getExpr()))
|
||||
for (auto [iter, lr] : range->getLoopVarRanges())
|
||||
nnet_assert(lr.first == 0 && lr.second > 0,
|
||||
"Tensor dims should start from 0.");
|
||||
}
|
||||
if (paddings.size() == 0)
|
||||
paddings = vector<int>(shape.size(), 0);
|
||||
assert(paddings.size() == shape.size());
|
||||
}
|
||||
|
||||
string TensorNode::toOutputShape() const {
|
||||
return "shape=" + serializeVec(shape) + " pad=" + serializeVec(paddings);
|
||||
}
|
||||
|
||||
string TensorNode::toReadable() const {
|
||||
string ret = name;
|
||||
string property = "<pad=";
|
||||
bool hasPaddings = false;
|
||||
for (size_t i = 0; i < paddings.size(); ++i) {
|
||||
if (i > 0)
|
||||
property += ",";
|
||||
property += to_string(paddings[i]);
|
||||
if (paddings[i])
|
||||
hasPaddings = true;
|
||||
}
|
||||
property += ">";
|
||||
return (hasPaddings) ? ret + property : ret;
|
||||
}
|
||||
|
||||
int TensorNode::getData(const Ref<vector<int>> &data, const vector<int> &idx) {
|
||||
assert(idx.size() == shape.size());
|
||||
for (size_t i = 0; i < idx.size(); ++i) {
|
||||
if (idx[i] < 0 || idx[i] >= shape[i]) {
|
||||
assert(0 - paddings[i] <= idx[i]);
|
||||
assert(idx[i] < shape[i] + paddings[i]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return data->at(getOffset(idx));
|
||||
}
|
||||
|
||||
size_t TensorNode::getOffset(const vector<int> &idx) {
|
||||
auto nDim = idx.size();
|
||||
assert(shape.size() == nDim);
|
||||
if (idx.empty()) {
|
||||
return 0;
|
||||
}
|
||||
for (size_t i = 0; i < nDim; i++) {
|
||||
if (idx[i] < 0 || shape[i] <= idx[i]) {
|
||||
return (size_t)-1;
|
||||
}
|
||||
}
|
||||
size_t offset = idx[0];
|
||||
size_t dim = 0;
|
||||
while (++dim < nDim) {
|
||||
offset = offset * shape[dim] + idx[dim];
|
||||
}
|
||||
return offset;
|
||||
}
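
// Worked example: with shape = {2, 3, 4} and idx = {1, 2, 3}, the loop
// computes offset = ((1 * 3) + 2) * 4 + 3 == 23, the usual row-major
// linearization 1 * 12 + 2 * 4 + 3.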

string RangeOpNode::toReadable() const {
    string ret;
    for (int i = 0; i < IterationType::NumIterationType; ++i) {
        ret += (i == Loop) ? "L" : "Sum";
        for (const auto &kv : vars[i]) {
            ret += "<" + kv.first->getName() + ":" +
                   std::to_string(kv.second.first) + ":" +
                   std::to_string(kv.second.second) + ">";
        }
        if (i == Loop && hasPaddings()) {
            ret += "<pad=";
            for (const auto &p : paddings) {
                ret += to_string(p) + ",";
            }
            ret += ">";
        }
    }
    if (auto sub = as<SubscriptNode>(getSummand()); sub) {
        ret += " ... " + serializeVec(sub->getIndex()) + "\n {" +
               sub->getObject()->toReadable() + "}";
    } else {
        ret += "\n {" + subExprs[Summand]->toReadable() + "}";
    }
    return ret;
}
int RangeOpNode::getNumOutputDims() const { return vars[Loop].size(); }
bool RangeOpNode::hasVar(int index, Var name) const {
    for (const auto &kv : vars[index])
        if (kv.first->equal(name))
            return true;
    return false;
}
int RangeOpNode::getVarIndex(int type, string name) {
    for (size_t i = 0; i < vars[type].size(); ++i)
        if (vars[type][i].first->equal(name))
            return i;
    assert(0);
    return 0;
}
Range RangeOpNode::getRange(const Var &var) const {
    for (const auto &varRanges : vars) {
        for (const auto &varRange : varRanges) {
            if (varRange.first->equal(var))
                return varRange.second;
        }
    }
    nnet_assert(0, "Var is not an iterator.");
    return Range();
}
VarRangePair RangeOpNode::getVarRange(const Var &var) const {
    for (const auto &varRanges : vars) {
        for (const auto &varRange : varRanges) {
            if (varRange.first->equal(var))
                return varRange;
        }
    }
    nnet_assert(0, "Var is not an iterator.");
    return VarRangePair();
}

void SubscriptNode::setObject(Expr e) {
    nnet_assert(as<TensorNode>(e) || as<RangeOpNode>(e),
                "Illegal subscripted object");
    indexed = e;
}

bool SubscriptNode::isRangeOpSubscripted() const {
    return as<RangeOpNode>(indexed) != nullptr;
}

vector<Range> SubscriptNode::getObjectRangesWithoutPaddings() const {
    if (isRangeOpSubscripted()) {
        return as<RangeOpNode>(indexed)->getOutputRanges();
    } else {
        vector<Range> ret;
        for (const auto &len : as<TensorNode>(indexed)->getShape())
            ret.emplace_back(0, len);
        return ret;
    }
}

vector<Range> SubscriptNode::getObjectRangesWithPaddings() const {
    vector<Range> ret;
    if (isRangeOpSubscripted()) {
        auto rangeOp = as<RangeOpNode>(indexed);
        ret = rangeOp->getOutputRanges();
        for (size_t i = 0; i < ret.size(); ++i) {
            if (int pad = rangeOp->getPaddings(i)) {
                ret[i].first -= pad;
                ret[i].second += pad;
            }
        }
    } else {
        auto tensor = as<TensorNode>(indexed);
        for (const auto &len : tensor->getShape())
            ret.emplace_back(0, len);
        for (int i = 0; i < tensor->getDims(); ++i) {
            if (int pad = tensor->getPadding(i)) {
                ret[i].first -= pad;
                ret[i].second += pad;
            }
        }
    }
    return ret;
}

optional<pair<Iterator, int>> BinaryOpNode::getModDivParameter() const {
    auto lhs = as<VarNode>(getLhs());
    auto rhs = as<ConstantNode>(getRhs());
    if (lhs == nullptr) {
        return {};
    }
    assert(rhs != nullptr);
    if (lhs->getType() != NodeType::VarNodeType) {
        nnet_unimplemented_halt();
    }
    if (rhs->getType() != NodeType::ConstantNodeType) {
        nnet_unimplemented_halt();
    }
    return pair(lhs, rhs->getValue());
}

pair<Expr, int> BinaryOpNode::getModDivExpr() const {
    auto constant = as<ConstantNode>(getRhs());
    assert(constant != nullptr);
    return pair(getLhs(), constant->getValue());
}

string BinaryOpNode::toReadable() const {
    string ret = "(";
    ret += subExprs[LHS]->toReadable();
    ret += " ";
    ret += opSymbols[static_cast<std::underlying_type_t<OpType>>(opType)];
    ret += " ";
    ret += subExprs[RHS]->toReadable();
    return ret + ")";
}

bool BinaryOpNode::isSwapable() const {
    switch (getOpType()) {
    case OpType::Add:
    case OpType::Mul:
        return true;
    case OpType::Sub:
    case OpType::Div:
    case OpType::Mod:
        return false;
    default:
        nnet_unimplemented_halt();
        return false;
    }
}

string SubscriptNode::toReadable() const {
    string ret;
    ret += "{";
    ret += indexed->toReadable();
    ret += "}[";
    for (size_t i = 0; i < subExprs.size(); ++i) {
        ret += subExprs[i]->toReadable();
        if (i != subExprs.size() - 1)
            ret += ", ";
        else
            ret += "]";
    }
    return ret;
}

string FuncNode::toReadable() const {
    string ret;
    if (funcType == FuncType::Relu)
        ret += "Relu";
    else if (funcType == FuncType::Tanh)
        ret += "Tanh";
    else
        nnet_unimplemented_halt();
    ret += "( ... " + serializeVec(object->getIndex()) + ")\n {" +
           object->getObject()->toReadable() + "}";
    return ret;
}

Expr operator+(const Expr &lhs, const Expr &rhs) {
    if (lhs == nullptr && rhs == nullptr)
        return nullptr;
    else if (lhs == nullptr)
        return rhs;
    else if (rhs == nullptr)
        return lhs;
    else
        return make_ref<BinaryOpNode>(OpType::Add, lhs, rhs);
}

BinaryOp operator-(const Expr &lhs, const Expr &rhs) {
    return make_ref<BinaryOpNode>(OpType::Sub, lhs, rhs);
}

BinaryOp operator*(const Expr &lhs, const Expr &rhs) {
    return make_ref<BinaryOpNode>(OpType::Mul, lhs, rhs);
}

BinaryOp operator/(const Expr &lhs, const Expr &rhs) {
    return make_ref<BinaryOpNode>(OpType::Div, lhs, rhs);
}

BinaryOp operator%(const Expr &lhs, const Expr &rhs) {
    return make_ref<BinaryOpNode>(OpType::Mod, lhs, rhs);
}

Expr operator+(const Expr &lhs, const int &rhs) {
    if (lhs != nullptr && rhs != 0)
        return make_ref<BinaryOpNode>(OpType::Add, lhs,
                                      make_ref<ConstantNode>(rhs));
    else if (lhs == nullptr)
        return make_ref<ConstantNode>(rhs);
    else
        return lhs;
}

Expr operator+(const int &lhs, const Expr &rhs) { return rhs + lhs; }

Expr operator-(const Expr &lhs, const int &rhs) { return lhs + (-rhs); }

Expr operator-(const int &lhs, const Expr &rhs) {
    if (rhs != nullptr)
        return make_ref<BinaryOpNode>(OpType::Sub, make_ref<ConstantNode>(lhs),
                                      rhs);
    else
        return make_ref<ConstantNode>(lhs);
}

Expr operator*(const Expr &lhs, const int &rhs) {
    if (rhs == 1)
        return lhs;
    else
        return make_ref<BinaryOpNode>(OpType::Mul, lhs,
                                      make_ref<ConstantNode>(rhs));
}

Expr operator*(const int &lhs, const Expr &rhs) {
    if (lhs == 1)
        return rhs;
    else
        return make_ref<BinaryOpNode>(OpType::Mul, make_ref<ConstantNode>(lhs),
                                      rhs);
}

bool operator==(const Var &lhs, const string &rhs) {
    return lhs->getName() == rhs;
}

bool operator==(const string &lhs, const Var &rhs) { return rhs == lhs; }
Expr operator%(const Expr &lhs, const int rhs) {
    return make_ref<BinaryOpNode>(OpType::Mod, lhs,
                                  make_ref<ConstantNode>(rhs));
}
Expr operator/(const Expr &lhs, const int rhs) {
    if (rhs == 1)
        return lhs;
    else
        return make_ref<BinaryOpNode>(OpType::Div, lhs,
                                      make_ref<ConstantNode>(rhs));
}
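
// Usage sketch: these overloads let index expressions be written
// arithmetically, e.g. `auto e = 4 * i + j - 2;` builds the BinaryOp tree
// ((4 * i) + j) + (-2). Note that the Expr-Expr operator+ absorbs nullptr
// operands, which lets an initially empty Expr be accumulated into without
// a first-iteration special case (see the addExpr helper in DLT::splitIndex).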

// Wrappers for type deduction
Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts) {
    return make_ref<SubscriptNode>(tensor, subscripts);
}

RangeOp makeRangeOperator(const vector<VarRangePair> &_loopIters,
                          const vector<VarRangePair> &_sumIters, Expr _summand,
                          const vector<int> &paddings) {
    return make_ref<RangeOpNode>(_loopIters, _sumIters, _summand, paddings);
}

// Wrappers for type deduction
Tensor makeTensor(const string &name, const vector<int> &shape,
                  const vector<int> &paddings, const Routine &source) {
    if (paddings.size() == 0)
        return make_ref<TensorNode>(name, shape,
                                    vector<int>((int)shape.size(), 0), source);
    else
        return make_ref<TensorNode>(name, shape, paddings, source);
}

int64_t TensorNode::getSize() const {
    int64_t size = 1;
    for (auto len : shape)
        size *= len;
    return size;
}
int RangeOpNode::getPaddings(int dim) const {
    return dim < (int)paddings.size() ? paddings[dim] : 0;
}

vector<int> RangeOpNode::getPaddings() const {
    if (paddings.size() > 0)
        return paddings;
    else
        return vector<int>(getNumOutputDims(), 0);
}

void RangeOpNode::setPaddings(vector<int> _paddings) { paddings = _paddings; }

bool RangeOpNode::hasPaddings() const {
    for (const auto &p : paddings)
        if (p > 0)
            return true;
    return false;
}

int64_t RangeOpNode::getFlops() const {
    int64_t sumCnt = getOutputSize();
    if (vars[Sum].empty())
        sumCnt = 0;
    else
        for (const auto &[var, range] : getSumVarRanges())
            sumCnt *= range.second - range.first;
    return sumCnt;
}
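
// Example: a RangeOp with loop extents {8, 224} and one sum variable of
// extent 16 reports 8 * 224 * 16 FLOPs; with no sum variables it reports 0,
// since no reduction arithmetic is performed.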

int64_t RangeOpNode::getInputSize(const RangeOp &self) const {
    int64_t ret = 0;
    auto tensors = GetTensorsVisitor().get(self);
    for (const auto &[name, tensor] : tensors)
        ret += tensor->getSize();
    return ret;
}

int64_t RangeOpNode::getOutputSize() const {
    int64_t loopCnt = 1;
    for (const auto &[var, range] : getLoopVarRanges())
        loopCnt *= range.second - range.first;
    return loopCnt;
}

vector<int> RangeOpNode::getOutputShape() const {
    vector<int> ret;
    for (const auto &[var, range] : getLoopVarRanges())
        ret.emplace_back(range.second - range.first);
    return ret;
}

vector<Range> RangeOpNode::getOutputRanges() const {
    vector<Range> ret;
    for (const auto &[var, range] : getLoopVarRanges())
        ret.emplace_back(range);
    for (size_t i = 0; i < paddings.size(); ++i)
        if (paddings[i] > 0) {
            ret[i].first -= paddings[i];
            ret[i].second += paddings[i];
        }
    return ret;
}

void FuncNode::setObject(Expr e) {
    object = as<SubscriptNode>(e);
    nnet_assert(object, "Illegal subscripted object");
}

} // namespace nnet

@@ -0,0 +1,653 @@
#include "nnet/iterator_table.h"
|
||||
#include "nnet/Visitor/MatchTableVisitor.h"
|
||||
#include "nnet/Visitor/SimplifyExprVisitor.h"
|
||||
#include "nnet/permutation.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace nnet {
|
||||
|
||||
bool IteratorTable::analyzeExpr(const RangeOp &_rangeOp) {
|
||||
rangeOp = _rangeOp;
|
||||
MatchTableVisitor mtVisitor;
|
||||
if (!mtVisitor(rangeOp))
|
||||
return false;
|
||||
tie(appearance, tensors, strideTable, subscripts) = mtVisitor.getResult();
|
||||
// dbg(appearance, tensors, strideTable);
|
||||
return true;
|
||||
}
|
||||
|
||||
// mapTensors
|
||||
void IteratorTable::buildTable(const vector<int> &_tensorMap) {
|
||||
tensorMap = _tensorMap;
|
||||
tensorMap_r.clear();
|
||||
tensorMap_r.resize(getNumInputs());
|
||||
tensorIDMap_r.clear();
|
||||
tensorIDMap_r.resize(getNumInputs());
|
||||
posTable.clear();
|
||||
posTable.resize(getNumRows());
|
||||
strideInTensor.clear();
|
||||
|
||||
// build reversed index (from tensorID to tensor/original tensor index)
|
||||
for (size_t i = 0; i < tensorMap.size(); ++i) {
|
||||
tensorMap_r[tensorMap[i]] = tensors[i];
|
||||
tensorIDMap_r[tensorMap[i]] = i;
|
||||
}
|
||||
strideInDim.clear();
|
||||
strideInDim.resize(getNumInputs());
|
||||
for (int i = 0; i < getNumInputs(); ++i)
|
||||
strideInDim[i].resize(getTensor(i)->getDims());
|
||||
|
||||
// auxiliary array for calculate in-dim stride
|
||||
vector<vector<int>> ldaInTensors(getNumInputs());
|
||||
for (int i = 0; i < getNumInputs(); ++i) {
|
||||
ldaInTensors[i].resize(getTensor(i)->getDims());
|
||||
ldaInTensors[i].back() = 1;
|
||||
for (int j = getTensor(i)->getDims() - 2; j >= 0; --j)
|
||||
ldaInTensors[i][j] =
|
||||
ldaInTensors[i][j + 1] * getTensor(i)->getShape(j + 1);
|
||||
}
|
||||
|
||||
map<TensorNode *, int> inputTensor2id;
|
||||
for (int i = 0; i < getNumInputs(); ++i)
|
||||
inputTensor2id[tensors[i].get()] = tensorMap[i];
|
||||
|
||||
iterInTensorDim.clear();
|
||||
iterInTensorDim.resize(getNumInputs());
|
||||
for (int i = 0; i < getNumInputs(); ++i) {
|
||||
iterInTensorDim[tensorMap[i]].resize(tensors[i]->getDims());
|
||||
}
|
||||
|
||||
for (const auto &[var, tds] : strideTable) {
|
||||
int bitmap = 0;
|
||||
strideInTensor.emplace(var, getNumInputs());
|
||||
for (const auto &[tensorNode, dim, stride] : tds) {
|
||||
int tensorID = inputTensor2id[tensorNode];
|
||||
int bit = 1 << tensorID;
|
||||
if (std::find_if(iterInTensorDim[tensorID][dim].begin(),
|
||||
iterInTensorDim[tensorID][dim].end(),
|
||||
[var = var](const Var &v) {
|
||||
return v->equal(var);
|
||||
}) == iterInTensorDim[tensorID][dim].end())
|
||||
iterInTensorDim[tensorID][dim].emplace_back(var);
|
||||
bitmap |= bit;
|
||||
if (strideInTensor[var][tensorID] < 0 || stride < 0)
|
||||
strideInTensor[var][tensorID] = -1;
|
||||
else
|
||||
strideInTensor[var][tensorID] += stride;
|
||||
// Update strideInDim
|
||||
assert(stride % ldaInTensors[tensorID][dim] == 0);
|
||||
strideInDim[tensorID][dim][var] =
|
||||
stride / ldaInTensors[tensorID][dim];
|
||||
}
|
||||
if (rangeOp->hasLoopVar(var))
|
||||
bitmap |= 1 << getNumInputs();
|
||||
posTable[bitmap].emplace_back(var);
|
||||
}
|
||||
}
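
// Row-index example: with two input tensors, an iterator appearing in
// tensor 0, tensor 1, and the output gets bitmap 0b111 == 7 (bit i for
// tensor i, the highest bit for the output), so posTable[7] collects exactly
// those iterators. Rows therefore classify iterators by where they appear,
// which is what matchPatternIT compares row by row.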
|
||||
|
||||
void IteratorTable::buildTableWithDefaultMap() {
|
||||
vector<int> tensorMap;
|
||||
for (int i = 0; i < getNumInputs(); ++i)
|
||||
tensorMap.emplace_back(i);
|
||||
buildTable(tensorMap);
|
||||
}
|
||||
int IteratorTable::getStridesInTensor(Iterator iter, int tensorID) const {
|
||||
return strideInTensor.at(iter).at(tensorID);
|
||||
}
|
||||
|
||||
vector<int> IteratorTable::getIterDimInTensor(int tensorID,
|
||||
const Iterator &iter) const {
|
||||
vector<int> ret;
|
||||
for (size_t i = 0; i < iterInTensorDim[tensorID].size(); ++i) {
|
||||
for (const auto &it : iterInTensorDim[tensorID][i])
|
||||
if (iter->equal(it))
|
||||
ret.emplace_back(i);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<Mismatch> IteratorTable::matchPatternIT(const Pattern &patternIT) {
|
||||
vector<Mismatch> ret;
|
||||
iterMap.clear();
|
||||
vector<vector<Iterator>> multiExprVar, multiPatternVar;
|
||||
// match iterators in single iterator rows
|
||||
for (int row = 0; row < getNumRows(); ++row) {
|
||||
int nExprVars = posTable[row].size(),
|
||||
nPatternVars = patternIT.posTable[row].size();
|
||||
if (nExprVars < nPatternVars) {
|
||||
ret.emplace_back(MismatchType::LessVar, row);
|
||||
continue;
|
||||
}
|
||||
if (nExprVars > nPatternVars) {
|
||||
ret.emplace_back(MismatchType::MoreVar, row);
|
||||
continue;
|
||||
}
|
||||
if (posTable[row].empty())
|
||||
continue;
|
||||
// prepare for many-to-many iterator mapping
|
||||
if (posTable[row].size() > 1) {
|
||||
multiExprVar.emplace_back(posTable[row]);
|
||||
multiPatternVar.emplace_back(patternIT.posTable[row]);
|
||||
}
|
||||
assert(!iterMap.count(posTable[row][0])); // check NO duplicate mapping
|
||||
if (posTable[row].size() == 1)
|
||||
iterMap[posTable[row][0]] = patternIT.posTable[row][0];
|
||||
}
|
||||
if (!ret.empty())
|
||||
return ret;
|
||||
PermutationGenerator permutationGenerator{multiPatternVar, multiExprVar};
|
||||
bool checked = false;
|
||||
// Permute iterator mappings to find a matched case
|
||||
do {
|
||||
auto mappingIter_r = permutationGenerator.get();
|
||||
for (const auto &[exprIter, patternIter] : iterMap)
|
||||
mappingIter_r[patternIter] = exprIter;
|
||||
auto mismatches = matchPatternITCheckStrides(patternIT, mappingIter_r);
|
||||
// if (mappingIter_r.count("_Conv_c"))
|
||||
// if (mappingIter_r["_Conv_n"] == "n" &&
|
||||
// mappingIter_r["_Conv_c"] == "c" &&
|
||||
// mappingIter_r["_Conv_h"] == "i22" &&
|
||||
// mappingIter_r["_Conv_r"] == "i4" &&
|
||||
// mappingIter_r["_Conv_w"] == "i17" &&
|
||||
// mappingIter_r["_Conv_s"] == "i14") {
|
||||
// dbg(ret.size());
|
||||
// if (mismatches.size() > 0)
|
||||
// dbg(mismatches.size(), mismatches[0].type);
|
||||
// }
|
||||
if (mismatches.size() == 0) { // matched
|
||||
ret = mismatches;
|
||||
// Complete iterator mapping
|
||||
for (const auto &[patternIter, exprIter] : mappingIter_r) {
|
||||
if (iterMap.count(exprIter))
|
||||
assert(iterMap[exprIter]->equal(patternIter));
|
||||
iterMap[exprIter] = patternIter;
|
||||
}
|
||||
break;
|
||||
} else if (!checked) {
|
||||
ret = mismatches;
|
||||
checked = true;
|
||||
} else if ((static_cast<int>(ret[0].type) <
|
||||
static_cast<int>(mismatches[0].type)) ||
|
||||
((static_cast<int>(ret[0].type) ==
|
||||
static_cast<int>(mismatches[0].type)) &&
|
||||
(mismatches.size() < ret.size()))) {
|
||||
ret = mismatches;
|
||||
}
|
||||
} while (permutationGenerator.next());
|
||||
// Build reverse iterator mapping
|
||||
if (ret.empty()) {
|
||||
iterToRange_r.clear();
|
||||
for (const auto &[exprIter, patternIter] : iterMap)
|
||||
iterToRange_r[patternIter] = rangeOp->getVarRange(exprIter);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<Mismatch> IteratorTable::matchPatternITCheckStrides(
|
||||
const Pattern &patternIT, PtrMap<Iterator, Iterator> mappingIter_r) {
|
||||
vector<Mismatch> ret;
|
||||
// Check strides against each stride constraint
|
||||
for (const auto &constraint : patternIT.getStrideConstraints()) {
|
||||
// TODO: supprot PROPOTIONAL constraint
|
||||
auto stride0 = strideInTensor.at(
|
||||
mappingIter_r[constraint.v0])[constraint.tensorID];
|
||||
auto stride1 = strideInTensor.at(
|
||||
mappingIter_r[constraint.v1])[constraint.tensorID];
|
||||
if (stride0 != stride1) {
|
||||
ret.emplace_back(Mismatch(MismatchType::StrideMismatch, -1));
|
||||
}
|
||||
}
|
||||
if (!ret.empty())
|
||||
return ret;
|
||||
// check the appearance of iterators inside tensors.
|
||||
// If mismatch, this can be repaired by guided DLT.
|
||||
for (int tensorID = 0; tensorID < getNumInputs(); ++tensorID) {
|
||||
int exprTensorDim = tensorMap_r[tensorID]->getDims();
|
||||
int patternTensorDim = patternIT.tensorMap_r[tensorID]->getDims();
|
||||
if (exprTensorDim != patternTensorDim) {
|
||||
ret.emplace_back(MismatchType::DLMismatch, tensorID);
|
||||
continue;
|
||||
}
|
||||
[&] {
|
||||
for (int dim = 0; dim < exprTensorDim; ++dim) {
|
||||
// If #iters is differnt, than DLT is required
|
||||
if (strideInDim[tensorID][dim].size() !=
|
||||
patternIT.strideInDim[tensorID][dim].size()) {
|
||||
ret.emplace_back(MismatchType::DLMismatch, tensorID);
|
||||
return;
|
||||
}
|
||||
for (const auto &[patternIter, patternStride] :
|
||||
patternIT.strideInDim[tensorID][dim]) {
|
||||
auto exprIter = mappingIter_r[patternIter];
|
||||
// If iters are differnt
|
||||
if (!strideInDim[tensorID][dim].count(exprIter)) {
|
||||
ret.emplace_back(MismatchType::DLMismatch, tensorID);
|
||||
return;
|
||||
}
|
||||
auto exprStride = strideInDim[tensorID][dim].at(exprIter);
|
||||
// TODO: for stride and dilation
|
||||
if (exprStride != patternStride) {
|
||||
ret.emplace_back(MismatchType::DLMismatch, tensorID);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
}
|
||||
if (!ret.empty())
|
||||
return ret;
|
||||
// check output data layout
|
||||
// Output dim mismatch is not implemented.
|
||||
if (patternIT.rangeOp->getNumOutputDims() != rangeOp->getNumOutputDims()) {
|
||||
ret.emplace_back(Mismatch{MismatchType::OutputDimismatch, 0});
|
||||
return ret;
|
||||
}
|
||||
for (size_t i = 0; i < rangeOp->getLoopVarRanges().size(); ++i) {
|
||||
if (!mappingIter_r[patternIT.rangeOp->getLoopVar(i)]->equal(
|
||||
rangeOp->getLoopVar(i))) {
|
||||
ret.emplace_back(MismatchType::OutputDLMismatch, getNumInputs(),
|
||||
mappingIter_r);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
pair<PatternTensorMap, PatternIterRangeMap>
|
||||
IteratorTable::getReverseMap() const {
|
||||
return {tensorMap_r, iterToRange_r};
|
||||
}
|
||||
|
||||
bool Pattern::isAllUniqueAccess(int tensorID) const {
|
||||
for (const auto &iterInDim : iterInTensorDim[tensorID]) {
|
||||
if (iterInDim.size() != 1)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Expr MatmulPattern::buildExpr(
    const Expr &expr, const vector<Tensor> &tensors,
    [[maybe_unused]] const PatternIterRangeMap &varRanges, string outputName,
    [[maybe_unused]] const IteratorTable &exprIT) const {
    // TODO support b
    assert(tensors.size() == 2);
    int b = 1;
    int m = tensors[0]->getShape(0), n = tensors[1]->getShape(0);
    int k = tensors[0]->getShape(1);
    // TODO: check strides
    // TODO: DLT for output?
    // FIXME: check the trans
    auto matmul = make_ref<MatmulNode>(expr, tensors[0], tensors[1], b, m, n, k,
                                       false, true);
    auto output = make_ref<TensorNode>(outputName, vector<int>{m, n},
                                       vector<int>{0, 0}, matmul);
    return output;
}

const Pattern &MatmulPattern::getMatmulPattern() {
    static class MatmulPattern exprIT;
    static bool inited = false;
    if (!inited) {
        inited = true;
        int M = 224, N = 8, K = 16;
        auto m = make_ref<VarNode>("_Matmul_m");
        auto n = make_ref<VarNode>("_Matmul_n");
        auto k = make_ref<VarNode>("_Matmul_k");
        auto A = make_ref<TensorNode>("_Matmul_A", vector<int>({M, K}));
        auto B = make_ref<TensorNode>("_Matmul_B", vector<int>({N, K}));
        auto subA = makeSubscript(A, {m, k});
        auto subB = makeSubscript(B, {n, k});
        auto range = makeRangeOperator({{m, {0, M}}, {n, {0, N}}},
                                       {{k, {0, K}}}, subA * subB);
        auto success = exprIT.analyzeExpr(range);
        assert(success);
        exprIT.buildTable({0, 1});
    }
    return exprIT;
}

const Pattern &ConvPattern::getPattern() {
    static class ConvPattern exprIT;
    static bool inited = false;
    if (!inited) {
        inited = true;
        // The shape is meaningless but cannot be zero for IT building
        int N = 8, C = 16, H = 224, W = 224, F = 16, R = 3, S = 3;
        auto A = make_ref<TensorNode>("_Conv_A", vector<int>({N, C, H, W}));
        auto B = make_ref<TensorNode>("_Conv_K", vector<int>({F, C, R, S}));
        auto subA = makeSubscript(A, {n, c, h + r, w + s});
        auto subB = makeSubscript(B, {f, c, r, s});
        auto range = makeRangeOperator(
            {{n, {0, 0}}, {f, {0, 0}}, {h, {0, 0}}, {w, {0, 0}}},
            {{c, {0, 0}}, {r, {0, 0}}, {s, {0, 0}}}, subA * subB);
        auto success = exprIT.analyzeExpr(range);
        assert(success);
        exprIT.buildTable({0, 1});
    }
    return exprIT;
}

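// Lower a matched expression to a Conv routine. Paddings are recovered from
// the (h, r) and (w, s) iterator ranges plus the constant offsets found in the
// subscripts; candidates whose Conv output shape disagrees with the RangeOp
// output shape are rejected.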
Expr ConvPattern::buildExpr(
    const Expr &expr, const vector<Tensor> &tensors,
    const PatternIterRangeMap &varRanges, string outputName,
    [[maybe_unused]] const IteratorTable &exprIT) const {
    // calculate paddings
    const auto &rangeH = varRanges.at(h).second;
    const auto &rangeR = varRanges.at(r).second;
    const auto &rangeW = varRanges.at(w).second;
    const auto &rangeS = varRanges.at(s).second;
    auto offsetH =
        SimplifyExprVisitor().getConstant(exprIT.getSubscript(0)->getIndex(2));
    auto offsetW =
        SimplifyExprVisitor().getConstant(exprIT.getSubscript(0)->getIndex(3));
    int ph = calcPadding(tensors[0], 2, rangeH, rangeR, offsetH);
    int pw = calcPadding(tensors[0], 3, rangeW, rangeS, offsetW);

    // TODO strided, dilated
    auto conv = make_ref<ConvNode>(expr, tensors[0], tensors[1], ph, pw);
    auto shape = conv->getShape();
    auto rangeOpShape = as<RangeOpNode>(expr)->getOutputShape();
    assert(shape.size() == rangeOpShape.size());
    dbg(shape, rangeOpShape);
    for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] != rangeOpShape[i]) {
            dbg("Warning: unmatched Conv output", shape, rangeOpShape);
            return nullptr;
        }
    }
    auto output =
        make_ref<TensorNode>(outputName, shape, vector<int>{0, 0, 0, 0}, conv);
    return output;
}

RangeOp IteratorTable::getRangeOp() const { return rangeOp; }

#define DEF_CONV_VAR(a)                                                        \
    const Var ConvPattern::a = make_ref<VarNode>("_Conv_" #a)
DEF_CONV_VAR(n);
DEF_CONV_VAR(c);
DEF_CONV_VAR(h);
DEF_CONV_VAR(w);
DEF_CONV_VAR(f);
DEF_CONV_VAR(r);
DEF_CONV_VAR(s);
#undef DEF_CONV_VAR

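// Padding needed on dimension `dim` for accesses of the form (h + r + offset):
// [l, r) below is the accessed coordinate interval, and the padding is its
// overflow below 0 or beyond the tensor extent, whichever is larger.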
int Pattern::calcPadding(const Tensor &tensor, int dim, Range rangeH,
                         Range rangeR, int offset) const {
    int l = rangeH.first + rangeR.first + offset;
    int r = rangeH.second + rangeR.second - 1 + offset;
    int ret = max(0, max(0 - l, r - tensor->getShape(dim)));
    // dbg(l, r, rangeH, rangeR, offset, ret, tensor->getPadding(dim));
    // check OutOfBound
    assert(ret <= tensor->getPadding(dim));
    return ret;
}

#define DEF_SG2BMM_VAR(a)                                                      \
    const Var Sg2bmmPattern::a = make_ref<VarNode>("_Sg2bmm_" #a)
DEF_SG2BMM_VAR(b);
DEF_SG2BMM_VAR(m);
DEF_SG2BMM_VAR(w);
DEF_SG2BMM_VAR(k);
#undef DEF_SG2BMM_VAR

const Pattern &Sg2bmmPattern::getPattern() {
    static class Sg2bmmPattern exprIT;
    static bool inited = false;
    if (!inited) {
        inited = true;
        // The shape is meaningless but cannot be zero for IT building
        int Batch = 8, M = 32, K = 224, W = 2;
        auto A = make_ref<TensorNode>("_Sg2bmm_A", vector<int>{Batch, M, K});
        auto B = make_ref<TensorNode>("_Sg2bmm_B", vector<int>{Batch, M, K});
        auto subA = makeSubscript(A, {b, m, k});
        auto subB = makeSubscript(B, {b, m + w, k});
        auto range =
            makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {-W, W + 1}}},
                              {{k, {0, K}}}, subA * subB);
        auto success = exprIT.analyzeExpr(range);
        assert(success);
        exprIT.buildTableWithDefaultMap();
    }
    return exprIT;
}

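// Lower a matched expression to a G2BMM routine; the window radius W is
// recovered as half of the (2W+1)-long range of iterator w.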
Expr Sg2bmmPattern::buildExpr(
    const Expr &expr, const vector<Tensor> &tensors,
    [[maybe_unused]] const PatternIterRangeMap &varRanges, string outputName,
    [[maybe_unused]] const IteratorTable &exprIT) const {
    assert(tensors.size() == 2);
    assert(tensors[0]->getDims() == 3 && tensors[1]->getDims() == 3);
    int Batch = tensors[0]->getShape(0);
    int M = tensors[0]->getShape(1);
    int K = tensors[0]->getShape(2);
    int W = getLength(varRanges.at(w).second) / 2;

    auto op = make_ref<G2bmmNode>(expr, tensors[0], tensors[1], Batch, M, W, K);
    auto shape = op->getShape();
    auto rangeOpShape = as<RangeOpNode>(expr)->getOutputShape();
    assert(shape.size() == rangeOpShape.size());
    for (size_t i = 0; i < shape.size(); ++i) {
        assert(shape[i] == rangeOpShape[i]);
    }
    auto output =
        make_ref<TensorNode>(outputName, shape, vector<int>{0, 0, 0}, op);
    return output;
}

#define DEF_LongformerGBMM_VAR(a)                                              \
    const Var LongformerGBMMPattern::a = make_ref<VarNode>("_lo_" #a)
DEF_LongformerGBMM_VAR(b);
DEF_LongformerGBMM_VAR(m);
DEF_LongformerGBMM_VAR(w);
DEF_LongformerGBMM_VAR(n);
#undef DEF_LongformerGBMM_VAR

const Pattern &LongformerGBMMPattern::getPattern() {
    static class LongformerGBMMPattern exprIT;
    static bool inited = false;
    if (!inited) {
        inited = true;
        // The shape is meaningless but cannot be zero for IT building
        int Batch = 8, M = 32, N = 224, W = 2;
        auto A =
            make_ref<TensorNode>("_lo_A", vector<int>{Batch, M, 2 * W + 1});
        auto B = make_ref<TensorNode>("_lo_B", vector<int>{Batch, M, N});
        auto subA = makeSubscript(A, {b, m, w});
        auto subB = makeSubscript(B, {b, m + w, n});
        auto range =
            makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, M}}},
                              {{w, {-W, W + 1}}}, subA * subB);
        auto success = exprIT.analyzeExpr(range);
        assert(success);
        exprIT.buildTableWithDefaultMap();
    }
    return exprIT;
}

Expr LongformerGBMMPattern::buildExpr(
    const Expr &expr, const vector<Tensor> &tensors,
    [[maybe_unused]] const PatternIterRangeMap &varRanges, string outputName,
    [[maybe_unused]] const IteratorTable &exprIT) const {
    assert(tensors.size() == 2);
    assert(tensors[0]->getDims() == 3 && tensors[1]->getDims() == 3);
    int Batch = tensors[0]->getShape(0);
    int M = tensors[0]->getShape(1);
    assert(tensors[0]->getShape(2) % 2 == 1);
    int W = tensors[0]->getShape(2) / 2;
    int N = tensors[1]->getShape(2);

    auto op = make_ref<GbmmNode>(expr, tensors[0], tensors[1], Batch, M, W, N);
    auto shape = op->getShape();
    auto rangeOpShape = as<RangeOpNode>(expr)->getOutputShape();
    assert(shape.size() == rangeOpShape.size());
    for (size_t i = 0; i < shape.size(); ++i) {
        assert(shape[i] == rangeOpShape[i]);
    }
    auto output = make_ref<TensorNode>(outputName, shape,
                                       vector<int>(shape.size(), 0), op);
    return output;
}

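// Dispatch from a routine type to its singleton pattern instance.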
const Pattern &getPattern(RoutineType targetOp) {
    switch (targetOp) {
    case RoutineType::MatmulNodeType:
        return MatmulPattern::getMatmulPattern();
    case RoutineType::ConvNodeType:
        return ConvPattern::getPattern();
    case RoutineType::G2bmmNodeType:
        return Sg2bmmPattern::getPattern();
    case RoutineType::GbmmNodeType:
        return LongformerGBMMPattern::getPattern();
    default:
        nnet_unimplemented_halt();
    }
}

string getPatternName(RoutineType targetOp) {
    switch (targetOp) {
    case RoutineType::MatmulNodeType:
        return "Matmul";
    case RoutineType::ConvNodeType:
        return "Conv";
    case RoutineType::G2bmmNodeType:
        return "G2bmm";
    case RoutineType::GbmmNodeType:
        return "Gbmm";
    default:
        nnet_unimplemented_halt();
    }
    return {};
}

#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
Expr ConvPattern::getExpr(Tensor A, Tensor K, int N, int C, int H, int W, int F,
                          int R, int S) {
    DEFINE_VAR(n);
    DEFINE_VAR(c);
    DEFINE_VAR(h);
    DEFINE_VAR(w);
    DEFINE_VAR(f);
    DEFINE_VAR(r);
    DEFINE_VAR(s);
    auto subA = makeSubscript(A, {n, c, h + r - R / 2, w + s - S / 2});
    auto subB = makeSubscript(K, {f, c, r, s});
    auto range =
        makeRangeOperator({{n, {0, N}}, {f, {0, F}}, {h, {0, H}}, {w, {0, W}}},
                          {{c, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA * subB);
    return range;
}

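// ConvTranspose is rewritten as a dense conv over half-resolution coordinates:
// each output position (h, w) is split into quotient/remainder pairs (x1, x2)
// and (y1, y2), so the stride-2 upsampling becomes pure index arithmetic.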
// Warn: F is the number of input channels, which is inverted compared with a
// normal Conv.
// Input / output layouts: NHWF -> NHWC, Kernel: RSFC
Expr ConvTransPattern::getExpr(Tensor A, Tensor K, int N, int C, int H, int W,
                               int F, int R, int S) {
    const int padding = 1 * (R - 1) - 1;
    assert(A->getPadding(2) == padding);
    assert(R == 4);
    assert(S == 4);
    const int OH = 2 * H, OW = 2 * W;
    DEFINE_VAR(n);
    DEFINE_VAR(c);
    DEFINE_VAR(f);
    DEFINE_VAR(r);
    DEFINE_VAR(s);
    DEFINE_VAR(x1);
    DEFINE_VAR(x2);
    DEFINE_VAR(y1);
    DEFINE_VAR(y2);
    DEFINE_VAR(i2);
    DEFINE_VAR(i4);
    DEFINE_VAR(h);
    DEFINE_VAR(w);
    // dilation * (kernel_size - 1) - padding
    // auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
    //                               vector<int>{0, padding, padding, 0});
    // auto K = make_ref<TensorNode>("K", vector<int>({R, S, F, C}));

    auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
    auto subK =
        makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c});
    // x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2

    auto range1 = makeRangeOperator(
        {
            {n, {0, N}},
            {c, {0, C}},
            {x1, {0, OH / 2 + 1}},
            {x2, {0, 2}},
            {y1, {0, OW / 2 + 1}},
            {y2, {0, 2}},
        },
        {{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
    auto sub0 = makeSubscript(
        range1, {n, c, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2});
    auto range0 = makeRangeOperator(
        {{n, {0, N}}, {h, {0, OH}}, {w, {0, OW}}, {c, {0, C}}}, {}, sub0);
    return range0;
}

pair<Expr, pair<Tensor, Tensor>> Sg2bmmPattern::getExpr(int Batch, int M, int K,
                                                        int W, int D) {
    DEFINE_VAR(b);
    DEFINE_VAR(m);
    DEFINE_VAR(w);
    DEFINE_VAR(k);
    auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
                                  vector<int>{0, 0, 0});
    auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
                                  vector<int>{0, D * W, 0});

    auto subA = makeSubscript(A, {b, m, k});
    auto subB = makeSubscript(B, {b, m + D * (w - W), k});
    auto range =
        makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {0, 2 * W + 1}}},
                          {{k, {0, K}}}, subA * subB);
    return {range, {A, B}};
}

pair<Expr, pair<Tensor, Tensor>>
LongformerGBMMPattern::getExpr(int Batch, int M, int W, int K, int dilation) {
    DEFINE_VAR(b);
    DEFINE_VAR(m);
    DEFINE_VAR(w);
    DEFINE_VAR(k);
    auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, 2 * W + 1}),
                                  vector<int>{0, 0, 0});
    auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
                                  vector<int>{0, dilation * W, 0});
    auto subA = makeSubscript(A, {b, m, w});
    auto subB = makeSubscript(B, {b, m + dilation * w - dilation * W, n});
    auto range = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, K}}},
                                   {{w, {0, 2 * W + 1}}}, subA * subB);
    return {range, {A, B}};
}

pair<Expr, pair<Tensor, Tensor>> MatmulPattern::getExpr(bool transA,
                                                        bool transB, int Batch,
                                                        int M, int N, int K) {
    DEFINE_VAR(b);
    DEFINE_VAR(m);
    DEFINE_VAR(n);
    DEFINE_VAR(k);
    auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
                                  vector<int>{0, 0, 0});
    auto B = make_ref<TensorNode>("B", vector<int>({Batch, K, N}),
                                  vector<int>{0, 0, 0});
    auto subA = makeSubscript(A, {b, m, k});
    auto subB = makeSubscript(B, {b, k, n});
    auto range = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, N}}},
                                   {{k, {0, K}}}, subA * subB);
    return {range, {A, B}};
}

#undef DEFINE_VAR

} // namespace nnet

@@ -0,0 +1,722 @@
#include "nnet/nmutator.h"
|
||||
#include "nnet/Visitor/FullPrinterVisitor.h"
|
||||
#include "nnet/Visitor/GetTensorsVisitor.h"
|
||||
#include "nnet/Visitor/MatchReshapeVisitor.h"
|
||||
#include "nnet/derivator.h"
|
||||
|
||||
#ifdef ABC
|
||||
namespace infini {
|
||||
|
||||
NMutator::NMutator() {}

NMutator::NMutator(const std::vector<int> &derivationRules)
    : mode{Mode::RuleBased}, derivationRules{derivationRules} {}

NMutator::~NMutator() {}

void NMutator::setToNaiveMembound() { mode = Mode::ToNaiveMembound; }

vector<Graph> NMutator::run(const Graph &in_graph) {
    vector<Graph> out_graphs;
    // Test helper: naively transform one Op to Membound
    if (mode == Mode::ToNaiveMembound) {
        runSingleOpToNaiveMembound(in_graph, out_graphs);
        dbg(out_graphs.size());
        return out_graphs;
    }
    // // Hack for HetConv fusion
    // if (statGraph(in_graph) == NMutator::SGType::HetConv) {
    //     dbg("Start fuse HetConv");
    //     out_graphs.emplace_back(fuseHetConv(nullptr, in_graph));
    // }
    // Clear the map from NNet input names to PET tensors
    inputsNameNToTensorT.clear();
    std::vector<Operator> computeOps;
    in_graph->getComputeOps(computeOps);
    // assert(computeOps.size() == 1);
    if (computeOps.size() == 1)
        runSingleOp(in_graph, out_graphs);
    // FIXME: runMultipleOps results in segfault
    // else
    //     runMultipleOps(in_graph, out_graphs);
    return out_graphs;
}

void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
                                          std::vector<Graph> &out_graphs) {
    std::vector<Operator> computeOps;
    in_graph->getComputeOps(computeOps);
    assert(computeOps.size() == 1);
    const auto &computeOp = computeOps[0];
    auto g = std::make_shared<infini::Graph>();
    auto expr = opToExpression(computeOp);
    auto inputsN = nnet::GetTensorsVisitor().get(expr);
    dbg(inputsN);
    g->membound(computeOp->getInputs(), computeOp->getOutputs(),
                {inputsN.at("A"), inputsN.at("K")}, expr, 0);
    auto subgraph = new Graph(g->getOperators());
    subgraph->print();
    out_graphs.emplace_back(subgraph);
}

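// Try the hand-written rewrites first; otherwise derive candidate expressions
// with the Derivator and convert each derived candidate back into a graph.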
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
    std::vector<Operator> computeOps;
    in_graph->getComputeOps(computeOps);
    if (infini::Graph g = transformTConv1x1(computeOps[0])) {
        Graph graph = new Graph(g->getOperators());
        out_graphs.emplace_back(graph);
        return;
    }
    // Commented for debug, not implemented yet
    // if (infini::Graph g = transformTConv3x3(computeOps[0])) {
    //     Graph graph = new Graph(g->getOperators());
    //     out_graphs.emplace_back(graph);
    //     return;
    // }
    if (infini::Graph g = transformDialtedConv(computeOps[0])) {
        Graph graph = new Graph(g->getOperators());
        out_graphs.emplace_back(graph);
        return;
    }
    // if (infini::Graph g = transformConv1x1(computeOps[0])) {
    //     Graph graph = new Graph(g->getOperators());
    //     out_graphs.emplace_back(graph);
    //     return;
    // }
    // if (infini::Graph g = transformConv1xk(computeOps[0])) {
    //     Graph graph = new Graph(g->getOperators());
    //     out_graphs.emplace_back(graph);
    //     return;
    // }

    auto expr = opToExpression(computeOps[0]);
    if (!expr)
        return;

    nnet::Derivator derivator(maxDepth);
    nnet::Formula conv_9x9(expr, 0);
    // const std::vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}; // Tconv
    // const std::vector<int> rules{1, 7, 7, 2, 8, 6, 6}; // G2BMM
    if (mode == Mode::Normal) {
        derivator.search(conv_9x9, 0);
    } else if (mode == Mode::RuleBased) {
        dbg(derivationRules);
        derivator.ruleBasedDFS(conv_9x9, 0, derivationRules);
    } else
        nnet_assert(0, "Unknown mode");
    const auto &candidates = derivator.getCandidates();
    dbg(candidates.size());
    // derivator.print();
    for (const auto &candidate : candidates) {
        // dbg(nnet::FullPrinterVisitor().print(candidate.root));
        if (auto g = expressionToGraph(candidate.root, in_graph)) {
            Graph graph = new Graph(g->getOperators());
            out_graphs.emplace_back(graph);
        }
        // break; // HACK: Debug only for the first subgraph
    }
    // dbg(out_graphs);
    // for (auto graph : out_graphs) {
    //     graph->print();
    // }
    cntStates += derivator.getNumIntermediateStates();
    cntCandidates += derivator.getNumCandidates();
}

void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {
    std::cout << "run multiple ops" << std::endl;
    in_graph->print();
    std::cout << std::endl;

    std::vector<Operator> computeOps;
    in_graph->getComputeOps(computeOps);
    dbg(computeOps);
    nnet::VecExpr exprs;
    for (const auto &op : computeOps)
        exprs.emplace_back(opToExpression(op));
    dbg(exprs);

    nnet::Derivator derivator;
    nnet::MultiFormulas origin(exprs, 0);
    bool canCombine = derivator.stageCombination(origin, 0);
    dbg(canCombine);
    const auto matmul0 = dynamic_cast<MatmulOp *>(computeOps[0]);
    assert(matmul0);
    // Build merged graph
    auto g = new infini::Graph();
    std::vector<Tensor *> inputsT, weightsT, outputsT;
    for (const auto &opT : computeOps) {
        inputsT.emplace_back(opT->getInputs(0));
        weightsT.emplace_back(opT->getInputs(1));
        outputsT.emplace_back(opT->getOutput());
    }
    const auto concat1 = g->concat(inputsT, 0);
    const auto concat2 = g->concat(weightsT, 0);
    const auto matmul = g->matmul(concat1->getOutput(), concat2->getOutput(),
                                  matmul0->getTransA(), matmul0->getTransB());
    g->split(matmul->getOutput(), outputsT, 0, computeOps.size());
    // Build computation graph in PET:
    g->updateConnection();
    Graph graph = new Graph(g->getOperators());
    out_graphs.emplace_back(graph);
    // DEBUG
    dbg(out_graphs);
    for (auto graph : out_graphs) {
        graph->print();
    }
}

NMutator::SGType NMutator::statGraph(Graph sg) {
    auto ops = sg->getOperators();
    switch (ops.size()) {
    case 0: {
        return Empty;
        break;
    }

    case 1: {
        if (ops[0]->getType() == Operator::Conv) {
            auto weight = ops[0]->getInputs()[1];
            auto r = weight->getDims()[2];
            auto s = weight->getDims()[3];
            if (((ConvOp *)sg->getOperators()[0])->getDh() == 1 &&
                ((ConvOp *)sg->getOperators()[0])->getDw() == 1 && r == 1 &&
                s == 1) {
                return Conv1X1;
            } else if (((ConvOp *)sg->getOperators()[0])->getDh() == 2 ||
                       ((ConvOp *)sg->getOperators()[0])->getDw() == 2) {
                return DilatedConv;
            } else {
                const Dim &inDim = ops[0]->getInputs()[0]->getDims();
                const Dim &wDim = ops[0]->getInputs()[1]->getDims();
                if (inDim[2] % 2 == 1 && inDim[3] % 2 == 1)
                    return NormalOddConv;
                else if (wDim[2] != wDim[3])
                    return TransKernelConv;
                else
                    return NormalConv;
            }
        } else if (ops[0]->getType() == Operator::Matmul) {
            return NormalMatmul;
        }
        break;
    }

    default:
        auto ty = ops[0]->getType();
        for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i) {
            if (ops[i]->getType() != ty)
                return Others;
        }
        if (ty == Operator::Conv) {
            std::vector<ConvOp *> convs;
            for (auto op : ops)
                convs.emplace_back(dynamic_cast<ConvOp *>(op));
            // TODO: 1x1 conv enlarge. 1x1 conv has 0 padding
            for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i)
                if (!convs[i]->same(*convs[0]))
                    return Others;
            auto inDim = ops[0]->getInputs(0)->getDims();
            // TODO: enlarge input tensor?
            for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i)
                if (ops[i]->getInputs(0)->getDims() != inDim)
                    return Others;
            auto weightDim = ops[0]->getInputs(1)->getDims();
            auto groupFlag = true;
            // TODO: kernel enlarge to group?
            for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i) {
                auto wDim = ops[i]->getInputs(1)->getDims();
                if (!(wDim[1] == weightDim[1] && wDim[2] == weightDim[2] &&
                      wDim[3] == weightDim[3] && wDim[2] == wDim[3])) {
                    groupFlag = false;
                    break;
                }
            }
            if (groupFlag)
                return GroupConv;
            // Hack for HetConv
            if (ops.size() == 2) {
                auto w1Dim = weightDim, w2Dim = ops[1]->getInputs(1)->getDims();
                auto hConvFlag = false;
                if (w1Dim[0] == w2Dim[0] && w1Dim[1] == w2Dim[1]) {
                    if (w1Dim[2] == 3 && w1Dim[3] == 3 && w2Dim[2] == 1 &&
                        w2Dim[3] == 1) {
                        hConvFlag = true;
                    }
                    if (w1Dim[2] == 1 && w1Dim[3] == 1 && w2Dim[2] == 3 &&
                        w2Dim[3] == 3) {
                        hConvFlag = true;
                    }
                }
                if (hConvFlag) {
                    // std::cout << "[nmutator stat graph]Het Conv found!"
                    //           << std::endl;
                    // ops[0]->print();
                    // std::cout << std::endl;
                    // ops[1]->print();
                    // std::cout << std::endl;
                    return HetConv;
                }
            }
            auto transGroupFlag = true;
            // TODO: transpose group conv with different f dim?
            for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i) {
                auto wDim = ops[i]->getInputs(1)->getDims();
                if (!(wDim[0] == weightDim[0] && wDim[1] == weightDim[1] &&
                      ((wDim[2] == weightDim[2] && wDim[3] == weightDim[3]) ||
                       (wDim[2] == weightDim[3] && wDim[3] == weightDim[2])))) {
                    transGroupFlag = false;
                    break;
                }
            }
            if (transGroupFlag)
                return TransposeGroupConv;
        } else if (ty == Operator::Matmul) {
            // check same input shape or not
            for (int i = 0; i < (int)ops.size() - 1; ++i) {
                assert(dynamic_cast<MatmulOp *>(ops[i])->getTransA() ==
                       dynamic_cast<MatmulOp *>(ops[i + 1])->getTransA());
                assert(dynamic_cast<MatmulOp *>(ops[i])->getTransB() ==
                       dynamic_cast<MatmulOp *>(ops[i + 1])->getTransB());
                if (ops[i]->getInputs()[0]->getDims() !=
                    ops[i + 1]->getInputs()[0]->getDims()) {
                    return Others;
                }
                if (ops[i]->getInputs()[1]->getDims() !=
                    ops[i + 1]->getInputs()[1]->getDims()) {
                    return Others;
                }
            }
            return BatchMatmul;
        }
        // TODO: others?
        break;
    }

    return Others;
}

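// Hash a compute op from its own hash plus its input and output dimensions
// weighted by large primes, so that any shape change perturbs the final value.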
uint64_t NMutator::computeHashForSingleComputeOp(const Operator op) {
    if (op->getType() == Operator::Conv) {
        auto conv = dynamic_cast<const ConvOp *>(op);
        auto hash = conv->getHash();
        auto inputDim = conv->getInputs()[0]->getDims();
        auto weightDim = conv->getOutputs()[0]->getDims();
        hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
                inputDim[2] * 10000103 + inputDim[3] * 10000121 +
                weightDim[0] * 10000139 + weightDim[1] * 10000141 +
                weightDim[2] * 10000169 + weightDim[3] * 10000189;
        return hash;
    } else if (op->getType() == Operator::ConvTrans) {
        auto conv = dynamic_cast<const ConvTransOp *>(op);
        auto hash = conv->getHash();
        auto inputDim = conv->getInputs()[0]->getDims();
        auto weightDim = conv->getOutputs()[0]->getDims();
        hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
                inputDim[2] * 10000103 + inputDim[3] * 10000121 +
                weightDim[0] * 10000139 + weightDim[1] * 10000141 +
                weightDim[2] * 10000169 + weightDim[3] * 10000189;
        return hash;
    } else if (op->getType() == Operator::Matmul) {
        static uint64_t matmulhash = 0;
        return matmulhash++;
    } else if (op->getType() == Operator::G2BMM) {
        auto g2bmm = dynamic_cast<const G2BMMOp *>(op);
        auto hash = g2bmm->getHash();
        auto inputDim = g2bmm->getInputs()[0]->getDims();
        auto weightDim = g2bmm->getOutputs()[0]->getDims();
        hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
                inputDim[2] * 10000103 + inputDim[3] * 10000121 +
                weightDim[0] * 10000139 + weightDim[1] * 10000141 +
                weightDim[2] * 10000169 + weightDim[3] * 10000189;
        return hash;
    } else if (op->getType() == Operator::GBMML) {
        auto gbmml = dynamic_cast<const GBMMLOp *>(op);
        auto hash = gbmml->getHash();
        auto inputDim = gbmml->getInputs()[0]->getDims();
        auto weightDim = gbmml->getOutputs()[0]->getDims();
        hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
                inputDim[2] * 10000103 + inputDim[3] * 10000121 +
                weightDim[0] * 10000139 + weightDim[1] * 10000141 +
                weightDim[2] * 10000169 + weightDim[3] * 10000189;
        return hash;
    } else {
        // Not impl
        assert(false);
        return 0;
    }
}

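// Translate a PET operator into an NNet expression through the corresponding
// pattern, recording which NNet tensor names map to which PET tensors.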
nnet::Expr NMutator::opToExpression(Operator op) {
    if (auto convOp = dynamic_cast<ConvOp *>(op)) {
        const auto &inputs = convOp->getInputs();
        const auto &AT = inputs[0];
        const auto &KT = inputs[1];
        const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi, ac] =
            convOp->getArgs(0);
        dbg(n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw);
        if (!(sh == 1 && sw == 1 && dh == 1 && dw == 1))
            return nullptr;
        assert(sh == 1 && sw == 1 && dh == 1 && dw == 1);
        inputsNameNToTensorT["A"] = AT;
        inputsNameNToTensorT["K"] = KT;
        const auto A = nnet::makeTensor("A", AT->getDims(),
                                        std::vector<int>{0, 0, ph, pw});
        const auto K = nnet::makeTensor("K", KT->getDims());
        return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s);
    } else if (auto convOp = dynamic_cast<ConvTransOp *>(op)) {
        const auto &AT = convOp->getInputs()[0];
        const auto &KT = convOp->getInputs()[1];
        inputsNameNToTensorT["A"] = AT;
        inputsNameNToTensorT["K"] = KT;
        const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi, ac] =
            convOp->getArgs(0);
        if (r != 4) {
            dbg("ConvTranspose R!=4. Skipped.", r);
            return nullptr;
        }
        int padding = 1 * (r - 1) - 1;
        const auto A = nnet::makeTensor(
            "A", AT->getDims(), std::vector<int>{0, padding, padding, 0});
        const auto K = nnet::makeTensor("K", KT->getDims());
        return nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s);
    } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(op)) {
        const auto &AT = g2bmmOp->getInputs()[0];
        const auto &BT = g2bmmOp->getInputs()[1];
        const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();

        const auto &[expr, inputsN] =
            nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
        inputsNameNToTensorT[inputsN.first->getName()] = AT;
        inputsNameNToTensorT[inputsN.second->getName()] = BT;
        return expr;
    } else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(op)) {
        const auto &AT = gbmmlOp->getInputs()[0];
        const auto &BT = gbmmlOp->getInputs()[1];
        const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
        const auto &[expr, inputsN] =
            nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
        inputsNameNToTensorT[inputsN.first->getName()] = AT;
        inputsNameNToTensorT[inputsN.second->getName()] = BT;
        dbg(b, m, w, k, dilation, expr);
        return expr;
    } else if (auto matmulOp = dynamic_cast<MatmulOp *>(op)) {
        const auto &AT = matmulOp->getInputs()[0];
        const auto &BT = matmulOp->getInputs()[1];
        const auto [transA, transB, b, m, n, k] = matmulOp->getArgs();
        const auto &[expr, inputsN] =
            nnet::MatmulPattern::getExpr(transA, transB, b, m, n, k);
        inputsNameNToTensorT[inputsN.first->getName()] = AT;
        inputsNameNToTensorT[inputsN.second->getName()] = BT;
        dbg(b, m, n, k, expr);
        return expr;
    }
    // else if (auto transposeOp = dynamic_cast<TransposeOp *>(op)) {
    //     return transposeOpToExpression(transposeOp);
    // }
    nnet_unimplemented_continue();
    return nullptr;
}

infini::Graph NMutator::fuseHetConv(nnet::Expr expr, Graph in_graph) {
    // Conv3x3+Conv1x1 => Gemm(nhw, f(rs+1), c) + Reduce
    auto g = std::make_shared<infini::Graph>();
    in_graph->print();
    assert(in_graph->getInputs().size() == 3);
    auto input = in_graph->getOperators()[0]->getInputs(0);
    auto conv = dynamic_cast<ConvOp *>(in_graph->getOperators()[0]);
    auto output = conv->getOutput();
    // auto input = g->reshape(input);
    auto inputTrans = g->transpose(input, 0, {-1, {0, 2, 3}, 1}, -1);
    // dbg(inputTrans->getOutput()->getDims());
    const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, G, bi, ac] =
        conv->getArgs(0);
    auto weight = g->tensor({1, c, f * (3 * 3 + 1)});
    dbg(weight->getDims());
    auto matmul = g->matmul(inputTrans->getOutput(), weight, false, false);
    auto bias = g->tensor({f});
    const double size = n * f * h * w * (3 * 3 + 1) * 4;
    // FIXME: add NNET tensors for verification
    auto membound =
        g->membound({matmul->getOutput(), bias}, {output}, {}, nullptr,
                    memboundTime(size), "Reduce_conv3x3+1x1");
    dbg(n, f, h, w);
    dynamic_cast<MemBoundOp *>(membound)->setNFHW(n, f, h, w);

    return new Graph(g->getOperators());
}

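// Rebuild a PET graph from a derived expression. FullPrinterVisitor lists the
// output tensor first, so tensors are allocated for every non-input entry and
// operators are emitted in reverse, producer-before-consumer, order.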
infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
    auto g = new infini::Graph();
    nnet::FullPrinterVisitor fullVisitor;
    const auto &tensorQueueN = fullVisitor.traverse(expr);
    // Build tensors; skip the first entry, which is the output
    auto nameNToTensorT = inputsNameNToTensorT;
    for (size_t i = 1; i < tensorQueueN.size(); ++i) {
        const auto &[nameN, routineN, tensorN] = tensorQueueN[i];
        // dbg(nameN, routineN, tensorN);
        if (!routineN) {
            // This is an input
            assert(nameNToTensorT.count(nameN));
        } else {
            assert(!nameNToTensorT.count(nameN));
            nameNToTensorT[nameN] = g->tensor(tensorN->getShape());
        }
    }
    const auto &outputsPET = in_graph->getOutputs();
    if (outputsPET.size() != 1) {
        nnet_unimplemented_continue();
        return nullptr;
    }
    nameNToTensorT[std::get<0>(tensorQueueN.at(0))] = outputsPET[0];
    // Build computation graph in PET:
    for (int i = tensorQueueN.size() - 1; i >= 0; --i) {
        const auto &[outputNameN, routineN, tensorN] = tensorQueueN[i];
        if (!routineN)
            continue;
        // dbg(outputNameN, routineN, tensorN, routineN->getType());
        if (auto op = nnet::as<nnet::ConvNode>(routineN)) {
            // g->conv(i8, w9, 2, 2);
            std::vector<nnet::Tensor> inputsN = op->getInputs();
            auto A = nameNToTensorT.at(inputsN[0]->getName());
            auto K = nameNToTensorT.at(inputsN[1]->getName());
            auto output = nameNToTensorT.at(outputNameN);
            const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
            g->conv(A, K, output, ph, pw, sh, sw, dh, dw);
        } else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
            assert(op->getInputs().size() == 1);
            nnet::MatchReshapeVisitor matchReshapeVisitor;
            if (matchReshapeVisitor(op->getExpr())) {
                auto input =
                    nameNToTensorT.at(op->getInputs().at(0)->getName());
                auto output = nameNToTensorT.at(outputNameN);
                g->reshape(input, output);
            } else {
                TensorVec inputsPET;
                TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
                for (const auto &inputN : op->getInputs())
                    inputsPET.emplace_back(
                        nameNToTensorT.at(inputN->getName()));
                // Re-estimate time here.
                ssize_t cnt = 0;
                for (const auto tensor : inputsPET)
                    cnt += tensor->size();
                for (const auto tensor : outputsPET)
                    cnt += tensor->size();
                g->membound(inputsPET, outputsPET, op->getInputs(),
                            op->getExpr(), memboundTime(cnt));
            }
        } else if (auto op = nnet::as<nnet::MatmulNode>(routineN)) {
            assert(op->getInputs().size() == 2);
            nnet::Tensor AN = op->getInputs()[0];
            nnet::Tensor BN = op->getInputs()[1];
            TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
                                   nameNToTensorT.at(BN->getName())};
            TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
            const auto &[b, m, n, k, transa, transb] = op->getArgs();
            g->matmul(inputsPET[0], inputsPET[1], outputsPET[0], transa,
                      transb);
        } else if (auto op = nnet::as<nnet::G2bmmNode>(routineN)) {
            assert(op->getInputs().size() == 2);
            nnet::Tensor AN = op->getInputs()[0];
            nnet::Tensor BN = op->getInputs()[1];
            TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
                                   nameNToTensorT.at(BN->getName())};
            TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
            const auto &[b, m, w, k, dilation] = op->getArgs();
            g->g2bmm(inputsPET[0], inputsPET[1], outputsPET[0], w, dilation);
        } else if (auto op = nnet::as<nnet::GbmmNode>(routineN)) {
            assert(op->getInputs().size() == 2);
            nnet::Tensor AN = op->getInputs()[0];
            nnet::Tensor BN = op->getInputs()[1];
            TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
                                   nameNToTensorT.at(BN->getName())};
            TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
            const auto &[b, m, w, n, dilation] = op->getArgs();
            g->gbmml(inputsPET[0], inputsPET[1], outputsPET[0], dilation);
        }
    }
    g->updateConnection();
    Graph graph = new Graph(g->getOperators());
    return graph;
}

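// Rewrite a dilated conv (dh == dw == d > 1, stride 1) as a dense conv over a
// d*d-times-larger batch of spatially downsampled slices, with membound ops
// performing the layout transposes on both sides.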
Graph NMutator::transformDialtedConv(Operator op) {
    if (auto convOp = dynamic_cast<ConvOp *>(op)) {
        if (convOp->getPh() == convOp->getDh() && convOp->getSh() == 1 &&
            convOp->getDh() > 1 && convOp->getDh() == convOp->getDw()) {
            const int d = convOp->getDh();
            assert(convOp->getInputs()[0]->getDims()[2] % d == 0);
            auto g = new infini::Graph();
            auto inputDims = convOp->getInputs(0)->getDims();
            auto weightDims = convOp->getInputs(1)->getDims();
            auto outputDims = convOp->getOutput()->getDims();
            auto newA = g->tensor({inputDims[0] * d * d, inputDims[1],
                                   inputDims[2] / d, inputDims[3] / d});
            // auto newW = g->tensor(
            //     {weightDims[0] * weightDims[1] * weightDims[3],
            //      weightDims[2]});
            auto newO =
                g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
                           weightDims[0] * weightDims[1] * weightDims[3]});
            g->membound(
                {convOp->getInputs(0)}, {newA}, {}, nullptr,
                memboundTime(convOp->getInputs(0)->size() + newA->size()),
                "DConv Transpose Input");
            // g->membound({convOp->getInputs(1)}, {newW}, {}, nullptr, 0,
            //             "Transpose Weight");
            g->conv(newA, convOp->getInputs(1), newO, 1, 1, 1, 1);
            g->reshape(newO, convOp->getOutput());
            dbg(newO->size(), convOp->getOutput()->size());
            assert(newO->size() == convOp->getOutput()->size());
            g->membound(
                {newO}, {convOp->getOutput()}, {}, nullptr,
                memboundTime(newO->size() + convOp->getOutput()->size()),
                "DConv Transpose Output");
            g->updateConnection();
            Graph graph = new Graph(g->getOperators());
            return graph;
        }
    }
    return nullptr;
}

double NMutator::memboundTime(ssize_t cnt) {
    return double(cnt) * 4 / bandwidth * 1000; // millisecond
}

double NMutator::memboundTime(const Dim &dims) {
    return memboundTime(dims.size());
}

Graph NMutator::transformTConv3x3(Operator op) {
    if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
        dbg(tconvOp->getInputs()[1]->getDims());
        if (tconvOp->getPh() == 1 && tconvOp->getSh() == 2 &&
            tconvOp->getInputs()[1]->getDims()[0] == 3 &&
            tconvOp->getInputs()[1]->getDims()[1] == 3) {
            auto g = new infini::Graph();
            auto inputDims = tconvOp->getInputs(0)->getDims();
            auto weightDims = tconvOp->getInputs(1)->getDims();
            auto outputDims = tconvOp->getOutput()->getDims();
            // NHWF
            auto newA = g->tensor(
                {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
            // RSFC
            auto newW = g->tensor(
                {weightDims[0] * weightDims[1] * weightDims[3], weightDims[2]});
            auto newO =
                g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
                           weightDims[0] * weightDims[1] * weightDims[3]});
            g->reshape(tconvOp->getInputs(0), newA);
            g->reshape(tconvOp->getInputs(1), newW);
            g->matmul(newA, newW, newO, 0, 1);
            // g->reshape(newO, tconvOp->getOutput());
            tconvOp->print();
            dbg(newO->size() * 4, tconvOp->getOutput()->size() * 9);
            assert(newO->size() * 4 == tconvOp->getOutput()->size() * 9);
            g->membound(
                {newO}, {tconvOp->getOutput()}, {}, nullptr,
                memboundTime(newO->size() + tconvOp->getOutput()->size()),
                "TConv3x3 reduce");
            g->updateConnection();
            Graph graph = new Graph(g->getOperators());
            return graph;
        }
    }
    return nullptr;
}

Graph NMutator::transformTConv1x1(Operator op) {
    if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
        if (tconvOp->getPh() == 0 && tconvOp->getSh() == 1) {
            auto g = new infini::Graph();
            auto inputDims = tconvOp->getInputs(0)->getDims();
            auto weightDims = tconvOp->getInputs(1)->getDims();
            auto outputDims = tconvOp->getOutput()->getDims();
            auto newA = g->tensor(
                {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
            auto newW = g->tensor(
                {weightDims[0] * weightDims[1] * weightDims[3], weightDims[2]});
            auto newO =
                g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
                           weightDims[0] * weightDims[1] * weightDims[3]});
            g->reshape(tconvOp->getInputs(0), newA);
            g->reshape(tconvOp->getInputs(1), newW);
            g->matmul(newA, newW, newO, 0, 1);
            g->reshape(newO, tconvOp->getOutput());
            g->updateConnection();
            Graph graph = new Graph(g->getOperators());
            return graph;
        }
    }
    return nullptr;
}

Graph NMutator::transformConv1x1(Operator op) {
    auto convOp = dynamic_cast<ConvOp *>(op);
    if (!convOp)
        return nullptr;
    if (convOp->getPh() == 0 && convOp->getSh() == 1 &&
        convOp->getInputs()[1]->getDims()[2] == 1 &&
        convOp->getInputs()[1]->getDims()[3] == 1) {
        // Transpose is required for BS>1
        // if (convOp->getInputs()[0]->getDims()[0] == 1) {
        auto g = new infini::Graph();
        auto inputDims = convOp->getInputs(0)->getDims();
        auto weightDims = convOp->getInputs(1)->getDims();
        auto outputDims = convOp->getOutput()->getDims();
        auto newA = g->tensor(
            {inputDims[1], inputDims[0] * inputDims[2] * inputDims[3]});
        auto newW = g->tensor({weightDims[0], weightDims[1]});
        auto newO = g->tensor(
            {weightDims[0], inputDims[0] * inputDims[2] * inputDims[3]});
        g->reshape(convOp->getInputs(0), newA);
        g->reshape(convOp->getInputs(1), newW);
        g->matmul(newW, newA, newO, 0, 0);
        g->reshape(newO, convOp->getOutput());
        g->updateConnection();
        Graph graph = new Graph(g->getOperators());
        return graph;
    }
    return nullptr;
}

Graph NMutator::transformConv1xk(Operator op) {
    auto convOp = dynamic_cast<ConvOp *>(op);
    if (!convOp)
        return nullptr;
    if (convOp->getSh() != 1 || convOp->getSw() != 1)
        return nullptr;
    bool a = convOp->getInputs()[1]->getDims()[2] == 1;
    bool b = convOp->getInputs()[1]->getDims()[3] == 1;
    if (!(a ^ b))
        return nullptr;
    convOp->print();
    auto g = new infini::Graph();
    auto inputDims = convOp->getInputs(0)->getDims();
    auto weightDims = convOp->getInputs(1)->getDims();
    auto outputDims = convOp->getOutput()->getDims();
    auto newA =
        g->tensor({inputDims[0] * inputDims[2] * inputDims[3], inputDims[1]});
    auto newW = g->tensor(
        {weightDims[0] * weightDims[2] * weightDims[3], weightDims[1]});
    auto newO = g->tensor({weightDims[0] * weightDims[2] * weightDims[3],
                           inputDims[0] * inputDims[2] * inputDims[3]});
    // g->reshape(convOp->getInputs(0), newA);
    g->membound({convOp->getInputs(0)}, {newA}, {}, nullptr,
                memboundTime(convOp->getInputs(0)->size() + newA->size()),
                "1xk input reshape");
    g->reshape(convOp->getInputs(1), newW);

    g->matmul(newW, newA, newO, 0, 1);
    g->membound({newO}, {convOp->getOutput()}, {}, nullptr,
                memboundTime(newW->size() + convOp->getOutput()->size()),
                "1xk reduce");
    g->updateConnection();
    Graph graph = new Graph(g->getOperators());
    return graph;
}

} // namespace infini

#endif
@@ -0,0 +1,35 @@
#include "nnet/permutation.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace nnet {
|
||||
|
||||
PermutationGenerator::PermutationGenerator(vector<vector<Iterator>> _from,
|
||||
vector<vector<Iterator>> _to)
|
||||
: from(_from), to(_to), mapping(from.size()) {
|
||||
assert(from.size() == to.size());
|
||||
for (size_t i = 0; i < from.size(); ++i)
|
||||
for (size_t j = 0; j < from[i].size(); ++j)
|
||||
mapping[i].emplace_back(j);
|
||||
}
|
||||
|
||||
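// Advance odometer-style: std::next_permutation wraps a group back to sorted
// order and returns false, in which case the next group to the left advances.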
bool PermutationGenerator::next() {
    if (mapping.empty())
        return false;
    for (int i = (int)mapping.size() - 1; i >= 0; --i) {
        if (std::next_permutation(mapping[i].begin(), mapping[i].end()))
            return true;
    }
    return false;
}

PtrMap<Iterator, Iterator> PermutationGenerator::get() const {
    if (mapping.empty())
        return {};
    PtrMap<Iterator, Iterator> ret;
    for (size_t i = 0; i < mapping.size(); ++i)
        for (size_t j = 0; j < mapping[i].size(); ++j)
            ret[from[i][j]] = to[i][mapping[i][j]];
    return ret;
}

} // namespace nnet

@@ -0,0 +1,92 @@
#include "nnet/ReplaceKit.h"
|
||||
#include "nnet/Visitor/ReplaceVariable.h"
|
||||
#include "nnet/Visitor/SimplifyFormulaMutator.h"
|
||||
|
||||
namespace nnet {
|
||||
|
||||
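// Rebuild a RangeOp after an iterator substitution: ranges of untouched
// iterators are kept, the replacement ranges are spliced into the loop or sum
// side, and the resulting iterator count is checked against the replacement.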
RangeOp ReplaceKit::replaceRangeOpIterator(const RangeOp &rangeOp,
                                           const Replace &replace,
                                           const Expr &replacedSummand) {
    vector<VarRangePair> newVarRangePairs(replace.newVarRanges);
    if (replace.iteratorType == IterationType::Loop) {
        for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
            if (!replace.isReplaced(var))
                newVarRangePairs.emplace_back(var, range);
        }
        // Check the number of loop iterators
        assert(newVarRangePairs.size() == rangeOp->getLoopVarRanges().size() -
                                              replace.oldIters.size() +
                                              replace.newIters.size());
        return makeRangeOperator(newVarRangePairs, rangeOp->getSumVarRanges(),
                                 replacedSummand);
    } else if (replace.iteratorType == IterationType::Sum) {
        for (const auto &[var, range] : rangeOp->getSumVarRanges()) {
            if (!replace.isReplaced(var))
                newVarRangePairs.emplace_back(var, range);
        }
        assert(newVarRangePairs.size() == rangeOp->getSumVarRanges().size() -
                                              replace.oldIters.size() +
                                              replace.newIters.size());
        return makeRangeOperator(rangeOp->getLoopVarRanges(), newVarRangePairs,
                                 replacedSummand, rangeOp->getPaddings());
    }
    assert(false);
    return nullptr;
}

Subscript ReplaceKit::buildSubscirptForLoopVarReplace(const RangeOp &inner,
                                                      const Replace &replace) {
    VecExpr subs(replace.phis);
    for (size_t i = 0; i < replace.newVarRanges.size(); ++i) {
        assert(replace.newIters[i]->equal(inner->getLoopVar(i)));
    }
    for (size_t i = replace.newVarRanges.size();
         i < inner->getLoopVarRanges().size(); ++i) {
        subs.emplace_back(inner->getLoopVar(i));
    }
    // Support for var reorder and replace at the same time (not enabled):
    // VecExpr subs;
    // for (size_t i = 0; i < inner->getLoopVarRanges().size(); ++i) {
    //     if (auto it = std::find(replace.newIters.begin(),
    //                             replace.newIters.end(),
    //                             inner->getLoopVar(i));
    //         it != replace.newIters.end()) {
    //         subs.emplace_back(replace.phis[it - replace.newIters.begin()]);
    //     } else {
    //         subs.emplace_back(inner->getLoopVar(i));
    //     }
    // }
    return makeSubscript(inner, subs);
}

ReplaceKit::buildDLTOuterRangeOp(const RangeOp &original,
|
||||
const Subscript &subscriptedNewRangeOp) {
|
||||
auto outer = make_ref<RangeOpNode>(*original);
|
||||
outer->setSummand(subscriptedNewRangeOp);
|
||||
outer->setSumIterator({});
|
||||
return outer;
|
||||
}
|
||||
|
||||
Expr ReplaceKit::replaceMultipleExprs(const Expr &cur,
|
||||
const vector<Var> &patterns,
|
||||
const VecExpr &replacements,
|
||||
bool simplify) {
|
||||
auto ret = cur;
|
||||
for (size_t i = 0; i < patterns.size(); ++i) {
|
||||
ret = replaceExpr(ret, patterns[i], replacements[i]);
|
||||
}
|
||||
if (simplify) {
|
||||
SimplifyFormulaMutator simplifyFormulaMutator;
|
||||
ret = simplifyFormulaMutator.simplify(ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Expr ReplaceKit::replaceExpr(const Expr &cur, const Expr &pattern,
|
||||
const Expr &replacement) {
|
||||
auto replace = ReplaceVariable(pattern, replacement);
|
||||
auto ret = replace(cur);
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace nnet
|
|
@@ -0,0 +1,136 @@
#include "nnet/routine.h"
|
||||
#include "nnet/Visitor/CloneMutator.h"
|
||||
#include <algorithm>
|
||||
namespace nnet {
|
||||
|
||||
RoutineNode::RoutineNode(Expr _expr, const vector<Tensor> &_inputs)
|
||||
: inputs(_inputs) {
|
||||
if (_expr)
|
||||
expr = CloneMutator().clone(_expr);
|
||||
else
|
||||
expr = nullptr;
|
||||
}
|
||||
|
||||
string MatmulNode::toReadable() const {
    std::ostringstream oss;
    assert(inputs.size() == 2);
    oss << "Matmul{bmnk = " << b << ", " << m << ", " << n << ", " << k
        << "; AB = " << inputs[0] << ", " << inputs[1]
        << "; transAB = " << transa << ", " << transb << "}";
    return oss.str();
}

bool operator==(const MatmulNode &lhs, const MatmulNode &rhs) {
    if (!((lhs.b == rhs.b) && lhs.m == rhs.m && lhs.n == rhs.n &&
          lhs.k == rhs.k && lhs.transa == rhs.transa &&
          lhs.transb == rhs.transb))
        return false;
    if (lhs.inputs.size() != rhs.inputs.size())
        return false;
    for (size_t i = 0; i < lhs.inputs.size(); ++i) {
        if (lhs.inputs[i]->getName() != rhs.inputs[i]->getName())
            return false;
    }
    return true;
}

string ConvNode::toReadable() const {
    std::ostringstream oss;
    assert(inputs.size() == 2);
    oss << "Conv{A = " << inputs[0]
        << " shape=" << serializeVec(inputs[0]->getShape())
        << ", K=" << inputs[1]
        << " shape=" << serializeVec(inputs[1]->getShape()) << ", p = " << ph
        << ", " << pw << ", s = " << sh << ", " << sw << ", d = " << dh << ", "
        << dw << "}";
    return oss.str();
}

bool operator==(const ConvNode &lhs, const ConvNode &rhs) {
    if (!(lhs.ph == rhs.ph && lhs.pw == rhs.pw && lhs.sh == rhs.sh &&
          lhs.sw == rhs.sw && lhs.dh == rhs.dh && lhs.dw == rhs.dw))
        return false;
    if (lhs.inputs.size() != rhs.inputs.size())
        return false;
    for (size_t i = 0; i < lhs.inputs.size(); ++i) {
        if (lhs.inputs[i]->getName() != rhs.inputs[i]->getName())
            return false;
    }
    return true;
}

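// Output spatial extents under this codebase's convention; for the common
// sh = sw = dh = dw = 1 case, oh reduces to the usual h - r + 1 + 2 * ph.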
vector<int> ConvNode::getShape() const {
    auto input = inputs[0], weight = inputs[1];
    auto n = input->getShape(0);
    auto h = input->getShape(2);
    auto w = input->getShape(3);
    auto f = weight->getShape(0);
    auto r = weight->getShape(2);
    auto s = weight->getShape(3);
    int on = n, oc = f;
    // Set padding size
    int oh = (h - (r - sh) * dh + ph * 2) / sh;
    int ow = (w - (s - sw) * dw + pw * 2) / sw;
    return {on, oc, oh, ow};
}

ConvArgs ConvNode::getArgs() const { return tuple(ph, pw, sh, sw, dh, dw); }

vector<int> G2bmmNode::getShape() const { return {b, m, 2 * w + 1}; }

vector<int> GbmmNode::getShape() const { return {b, m, n}; }

string ElementWiseNode::toReadable() const {
    std::ostringstream oss;
    oss << "EleWise{";
    for (const auto &input : inputs)
        oss << input << ", ";
    oss << "}";
    return oss.str();
}

double ElementWiseNode::getEstimatedTime() const {
    int64_t cntElements = 0;
    // For unimplemented transpose
    assert(inputs.size() > 0);
    if (!expr) {
        assert(inputs.size() == 1);
    }
    for (const auto &input : inputs)
        cntElements += input->getSize();
    int64_t outputSize = 1;
    for (const auto &len : outputShape)
        outputSize *= len;
    cntElements += outputSize;

    const double bandwidth = 200 * 1000000;
    // dbg(inputs, inputs[0]->getShape(), cntElements,
    //     (cntElements * 4) / bandwidth);
    return double(cntElements * 4) / bandwidth; // ms
}

string G2bmmNode::toReadable() const {
    std::ostringstream oss;
    oss << "G2bmm{";
    for (const auto &input : inputs)
        oss << input << ", ";
    oss << ", bmwk = " << b << " " << m << " " << w << " " << k << "}";
    return oss.str();
}

string GbmmNode::toReadable() const {
    std::ostringstream oss;
    oss << "Gbmm{";
    for (const auto &input : inputs)
        oss << input << ", ";
    oss << ", bmwn = " << b << " " << m << " " << w << " " << n << "}";
    return oss.str();
}

G2bmmArgs G2bmmNode::getArgs() const { return {b, m, w, k, 1}; }

GbmmArgs GbmmNode::getArgs() const { return {b, m, w, n, 1}; }

} // namespace nnet

@@ -0,0 +1,70 @@
#include "nnet/Visitor/FullPrinterVisitor.h"
|
||||
#include "nnet/Visitor/GetTensorsVisitor.h"
|
||||
#include "nnet/Visitor/HashVisitor.h"
|
||||
#include "nnet/Visitor/Interpreter.h"
|
||||
#include "nnet/Visitor/Serializer.h"
|
||||
namespace nnet {
|
||||
|
||||
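// Count how many derived candidates hash to the same value as the serialized
// reference expression stored in `fn`.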
int matchExprResult(Derivator &derivator, string fn) {
    auto ans = Serializer().deserialize(fn);
    auto hashAns = HashVisitor()(ans);
    int match = 0;
    for (const auto &candidate : derivator.getCandidates()) {
        auto hashCandidate = HashVisitor()(candidate.root);
        match += (hashAns == hashCandidate);
    }
    return match;
}

bool checkExprLogSame(string fnPrefix, int start, int end) {
    Serializer serializer;
    string fn0 = fnPrefix + to_string(start) + ".expr";
    Expr expr0 = serializer.deserialize(fn0);
    RangeOp range0 = as<RangeOpNode>(expr0);
    Interpreter interpreter(range0);
    auto ans0 = interpreter.interpretUniformSample(range0);
    dbg(expr0, ans0);
    for (int i = start + 1; i < end; ++i) {
        string fn1 = fnPrefix + to_string(i) + ".expr";
        Expr expr1 = serializer.deserialize(fn1);
        RangeOp range1 = as<RangeOpNode>(expr1);
        dbg(fn1, expr1);
        auto ans1 = interpreter.interpretUniformSample(range1);
        dbg(ans1);
        if (ans0.size() != ans1.size())
            return false;
        for (size_t j = 0; j < ans0.size(); ++j)
            if (ans0[j] != ans1[j])
                return false;
    }
    return true;
}

bool checkExprsEquvivalence(VecExpr exprs) {
    if (exprs.size() < 2)
        return true;
    auto inputsMap0 = GetTensorsVisitor().get(exprs[0]);
    RangeOp range0 = as<RangeOpNode>(exprs[0]);
    Interpreter interpreter(range0);
    auto ans0 = interpreter.interpretUniformSample(range0);
    for (size_t i = 1; i < exprs.size(); ++i) {
        RangeOp range1 = as<RangeOpNode>(exprs[i]);
        auto inputsMap1 = GetTensorsVisitor().get(range1);
        // If exprs[0] and exprs[i] have different inputs, skip the check
        if (inputsMap0.size() != inputsMap1.size())
            return true;
        for (const auto &[name, tensor] : inputsMap0) {
            if (!inputsMap1.count(name))
                return true;
        }
        auto ans1 = interpreter.interpretUniformSample(range1);
        if (ans0.size() != ans1.size())
            return false;
        for (size_t j = 0; j < ans0.size(); ++j)
            if (ans0[j] != ans1[j])
                return false;
    }
    return true;
}

} // namespace nnet
@@ -0,0 +1,12 @@
#include "nnet/common.h"
|
||||
namespace nnet {
|
||||
|
||||
std::string pointer_to_hex(void *i) {
|
||||
std::stringstream stream;
|
||||
// stream << "0x" << std::setfill('0') << std::setw(sizeof(void *) * 2) <<
|
||||
// std::hex
|
||||
// << i;
|
||||
stream << std::hex << i;
|
||||
return stream.str();
|
||||
}
|
||||
} // namespace nnet
|
|
@@ -0,0 +1,97 @@
#include "nnet/visitor.h"
|
||||
namespace nnet {
|
||||
|
||||
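// Mutator convention: every visit_ returns a mutated copy of the node, or
// nullptr when the subtree is unchanged, letting parents reuse unmodified
// children as-is.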
Expr Mutator::visit_([[maybe_unused]] const Constant &c) { return nullptr; }

Expr Mutator::visit_(const BinaryOp &c) {
    if (verbose)
        dbg(*c);
    bool modified = false;
    auto ret = make_ref<BinaryOpNode>(*c);
    if (auto e = this->dispatch(ret->getLhs()); e) {
        modified = true;
        ret->setLhs(e);
    }
    if (auto e = this->dispatch(ret->getRhs()); e) {
        modified = true;
        ret->setRhs(e);
    }
    return (modified) ? ret : nullptr;
}

Expr Mutator::visit_(const RangeOp &c) {
    if (verbose)
        dbg(*c);
    bool modified = false;
    auto ret = make_ref<RangeOpNode>(*c);
    if (auto mutated = this->dispatch(ret->getSummand()); mutated) {
        modified = true;
        ret->setSummand(mutated);
    }
    // Do NOT visit iterators and their ranges
    return (modified) ? ret : nullptr;
}

Expr Mutator::visit_(const Subscript &c) {
    if (verbose)
        dbg(*c);
    bool modified = false;
    auto ret = make_ref<SubscriptNode>(*c);
    for (size_t i = 0; i < ret->getDims(); ++i) {
        const auto &e = ret->getIndex(i);
        if (const auto &mutated = this->dispatch(e); mutated) {
            modified = true;
            ret->setIndex(i, mutated);
        }
    }
    if (auto mutated = this->dispatch(ret->getObject()); mutated) {
        modified = true;
        ret->setObject(mutated);
    }
    return (modified) ? ret : nullptr;
}

Expr Mutator::visit_([[maybe_unused]] const Var &c) { return nullptr; }

Expr Mutator::visit_([[maybe_unused]] const Tensor &c) { return nullptr; }

Expr Mutator::visit_(const Func &c) {
    if (verbose)
        dbg(c);
    bool modified = false;
    auto ret = make_ref<FuncNode>(*c);
    if (auto mutated = dispatch(c->getObject())) {
        modified = true;
        ret->setObject(mutated);
    }
    return (modified) ? ret : nullptr;
}

void ExprTreeVisitor::visit_(const RangeOp &c) {
    if (inRange)
        dispatch(c->getSummand());
}
void ExprTreeVisitor::visit_(const BinaryOp &c) {
    if (inBinary) {
        dispatch(c->getLhs());
        dispatch(c->getRhs());
    }
}
void ExprTreeVisitor::visit_(const Subscript &c) {
    if (inSub) {
        dispatch(c->getObject());
        for (const auto &index : c->getIndex())
            dispatch(index);
    }
}
void ExprTreeVisitor::visit_([[maybe_unused]] const Var &c) {}
void ExprTreeVisitor::visit_([[maybe_unused]] const Constant &c) {}
void ExprTreeVisitor::visit_(const Tensor &c) {
    if (inTensor && c->getSource()) {
        if (const auto &expr = c->getSource()->getExpr(); expr)
            dispatch(expr);
    }
}
void ExprTreeVisitor::visit_(const Func &c) { dispatch(c->getObject()); }

} // namespace nnet
@@ -0,0 +1,49 @@
#include "code_engine.h"
|
||||
#include "nnet/expr.h"
|
||||
#include "nnet/nmutator.h"
|
||||
#include "operator.h"
|
||||
#include "search_engine.h"
|
||||
#include "tensor.h"
|
||||
#include "gtest/gtest.h"
|
||||
using namespace nnet;
|
||||
using namespace std;
|
||||
#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
|
||||
|
||||
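// Build a Relu-wrapped range expression by hand, wrap it in a membound op, and
// drive the search engine end to end, emitting the generated CUDA to res.cu.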
TEST(Activation, Relu) {
    const int n_heads = 8, seq_len = 10000, feat_len = 512;
    // dilation_heads = 2;
    const int Batch = n_heads, M = seq_len, K = feat_len, W = 32;
    DEFINE_VAR(b);
    DEFINE_VAR(m);
    DEFINE_VAR(w);
    DEFINE_VAR(k);
    auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
                                  vector<int>{0, 0, 0});

    auto subA = makeSubscript(A, {b, m, k});
    auto innerRange = makeRangeOperator(
        {{b, {0, Batch}}, {m, {0, M}}, {k, {0, K}}}, {}, subA);
    auto outerSub = makeSubscript(innerRange, {b, m, k});
    // auto subB = makeSubscript(B, {b, m + dilation * (w - W), k});
    auto relu = make_ref<FuncNode>(subA, FuncType::Relu);
    auto range =
        makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {0, 2 * W + 1}}},
                          {{k, {0, K}}}, relu);
    dbg(range);

    auto g = new tpm::Graph();
    auto i0 = g->tensor({Batch, M, K});
    auto i1 = g->tensor({Batch, M, 2 * W + 1});

    tpm::TensorVec inputsT{i0};
    tpm::TensorVec outputsT{i1};
    g->membound(inputsT, outputsT, {A}, range, 0);

    g->updateConnection();
    std::shared_ptr<tpm::SubGraph> graph, bestGraph;
    graph = std::make_shared<tpm::SubGraph>(g->getOperators());
    tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>());
    searchEngine.run(graph, bestGraph);
    tpm::CodeEngine codeEngine;
    codeEngine.genCode(bestGraph, "res.cu");
}