Add: nnet code

This commit is contained in:
Liyan Zheng 2022-08-08 16:02:07 +08:00
parent 1205240218
commit b7e2096a26
239 changed files with 69388 additions and 13 deletions

View File

@ -4,10 +4,10 @@ include(CMakeDependentOption)
project(InfiniTensor C CXX) project(InfiniTensor C CXX)
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them. # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
option(BUILD_TEST "Build tests" OFF) option(BUILD_TEST "Build tests" ON)
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" ON BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" ON BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF)
set(DEFAULT_BUILD_TYPE "RelWithDebInfo") set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
@ -31,16 +31,24 @@ if(OpenMP_CXX_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif() endif()
# add_subdirectory(3rd-party/pybind11)
# add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
include_directories(include) include_directories(include)
# # Pybind11
# add_subdirectory(3rd-party/pybind11)
# include_directories(3rd-party/pybind11/include) # include_directories(3rd-party/pybind11/include)
# include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
# nlohmann_json
add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
if(BUILD_TEST) if(BUILD_TEST)
set(BUILD_GMOCK off) # TODO: chekc set
set(INSTALL_GTEST off) set(BUILD_GMOCK
OFF
CACHE BOOL "Do not build gmock" FORCE)
set(INSTALL_GTEST
OFF
CACHE BOOL "Do not install gtest" FORCE)
add_subdirectory(3rd-party/googletest) add_subdirectory(3rd-party/googletest)
include_directories(3rd-party/googletest/googletest/include) include_directories(3rd-party/googletest/googletest/include)
endif() endif()
@ -59,13 +67,26 @@ add_library(InfiniTensor SHARED ${SRC})
# pybind11_add_module(infini MODULE ${FFI}) # pybind11_add_module(infini MODULE ${FFI})
# target_link_libraries(infini PRIVATE infini_cpp) # target_link_libraries(infini PRIVATE infini_cpp)
if(BUILD_TEST) function(build_test files)
enable_testing() # Non-recursive glob for skip failed tests
file(GLOB_RECURSE TEST_SOURCES test/*.cc) file(GLOB TEST_SOURCES ${files})
foreach(testsourcefile ${TEST_SOURCES}) foreach(testsourcefile ${TEST_SOURCES})
get_filename_component(testname ${testsourcefile} NAME_WE) get_filename_component(testname ${testsourcefile} NAME_WE)
add_executable(${testname} ${testsourcefile}) add_executable(${testname} ${testsourcefile})
target_link_libraries(${testname} InfiniTensor gtest_main gtest) target_link_libraries(${testname} InfiniTensor GTest::gtest_main)
add_test(NAME ${testname} COMMAND ${testname}) add_test(NAME ${testname} COMMAND ${testname})
endforeach(testsourcefile ${TEST_SOURCES}) endforeach(testsourcefile ${TEST_SOURCES})
endfunction()
if(BUILD_TEST)
enable_testing()
if(BUILD_TEST_CORE)
build_test(test/core/*.cc)
endif()
if(BUILD_TEST_PET)
build_test(test/pet/*.cc)
endif()
if(BUILD_TEST_EINNET)
build_test(test/nnet/*.cc)
endif()
endif() endif()

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class MatchComputationKernel : public Pass {
public:
MatchComputationKernel(Derivator &derivator)
: Pass(derivator, "MatchComputationKernel") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class MatchMemBoundKernel : public Pass {
public:
MatchMemBoundKernel(Derivator &derivator)
: Pass(derivator, "MatchMemBoundKernel") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

41
include/nnet/Pass/Pass.h Normal file
View File

@ -0,0 +1,41 @@
#pragma once
#include "nnet/derivator.h"
namespace nnet {
class Pass {
private:
VecExpr transformations;
protected:
Derivator &derivator;
string passName;
/**
* @brief // False if does not add log in Derivator. It should be false for
* single Pass test to avoid mismatch of passInfos and passMsgs due to
* different number of "run" and "nextStep".
*/
bool enableLogging, enableDebug;
virtual void transform(Formula &origin, int depth, Expr &rCur) = 0;
void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
const string &ruleInfo = "");
Var getNewVar();
string newTensorName();
private:
void initialize(Formula &origin, const Expr &rCur);
void finalize();
public:
Pass(Derivator &derivator, const string &passName);
virtual ~Pass();
void run(Formula &origin, int dfsDepth, Expr &rCur);
void setEnableLogging(bool value);
void setEnableDebug(bool value);
const VecExpr &getTransformations();
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"
namespace nnet {
class Rule1VariableSplit : public Pass {
public:
Rule1VariableSplit(Derivator &derivator)
: Pass(derivator, "Rule1VariableSplit") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
vector<Replace> getSplitableVar(const RangeOp &rangeOp);
Expr replaceIters(Expr cur, const Replace &replace);
};
} // namespace nnet

View File

@ -0,0 +1,29 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"
namespace nnet {
class Rule2VariableMerging : public Pass {
private:
map<int, vector<Var>> substituteRules;
public:
Rule2VariableMerging(Derivator &derivator)
: Pass(derivator, "Rule2VariableMerging") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
vector<Replace> getMergableReplaces(RangeOp rangeOp, int depth);
optional<Replace> getReplaceMergingTwoLoopIters(const RangeOp &rangeOp,
pair<Iterator, int> pairA,
pair<Iterator, int> pairB,
const IteratorTable &exprIT,
int tensorID);
optional<Replace> getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
pair<Iterator, int> pa,
pair<Iterator, int> pb);
};
} // namespace nnet

View File

@ -0,0 +1,19 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule3StageSplit : public Pass {
private:
map<int, vector<Var>> substituteRules;
public:
Rule3StageSplit(Derivator &derivator)
: Pass(derivator, "Rule3StageSplit") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
vector<vector<Var>> getSplitSummationIters(RangeOp rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,22 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule4StageMerging : public Pass {
bool success, mergeStageWithCalc;
public:
Rule4StageMerging(Derivator &derivator)
: Pass(derivator, "Rule4StageMerging"), success(false),
mergeStageWithCalc(false) {}
bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
bool mergeStageWithCalc = false);
bool isSuccessful();
void setMergeStageWithCalc(bool value);
private:
virtual void transform(Formula &origin, int depth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule5RangeRelaxation : public Pass {
public:
Rule5RangeRelaxation(Derivator &derivator)
: Pass(derivator, "Rule5RangeRelaxation") {}
Expr rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
private:
virtual void transform(Formula &origin, int depth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,17 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule6KenerlMatching : public Pass {
public:
Rule6KenerlMatching(Derivator &derivator)
: Pass(derivator, "Rule6KenerlMatching") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
// RE: seperating this func is a choice.
VecExpr matchElementWise(const RangeOp &rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule7DLT : public Pass {
public:
Rule7DLT(Derivator &derivator) : Pass(derivator, "Rule7DLT") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
Expr buildDLTSingleRangeOp(const RangeOp &original, const Expr &newSummand);
vector<int> getFactors();
};
} // namespace nnet

View File

@ -0,0 +1,48 @@
#pragma once
#include "nnet/Pass/Pass.h"
#include "nnet/ReplaceKit.h"
namespace nnet {
class Rule8GuidedDLT : public Pass {
public:
Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {}
VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur,
bool debug = false);
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
/**
* @brief If only one row miss match (more iterators mismatch), directly do
* data layout construction according to the IT.
*
* @return Expr Return nullptr if failed.
*/
Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch,
const IteratorTable &exprIT, const Pattern &pattern);
/**
* @brief Check whether two iterators overlap each other. If overlapping, we
* cannot simply reconstruct the tensor into a new one by seperate all
* iterators into different dimensions.
*/
bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT,
int tensorID);
/**
* @brief Only product of two tensors can be guided DLTed.
*
* @param cur
* @return true
* @return false
*/
bool statisfyGuidedDLT(RangeOp cur) const;
/**
* @brief Deal with output DLT mismatch only.
*/
Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch,
const IteratorTable &exprIT,
const Pattern &pattern);
Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace,
vector<Var> tensorDimAxes, vector<int> newShape);
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule90TwoStageElementWise : public Pass {
public:
Rule90TwoStageElementWise(Derivator &derivator)
: Pass(derivator, "Rule90TwoStageElementWise") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
VecExpr matchTwoStageElementWise(const RangeOp &rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule91MergeStagesWithSum : public Pass {
public:
Rule91MergeStagesWithSum(Derivator &derivator)
: Pass(derivator, "Rule91MergeStagesWithSum") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/Pass/Pass.h"
namespace nnet {
class Rule9RangeMagnify : public Pass {
public:
Rule9RangeMagnify(Derivator &derivator)
: Pass(derivator, "Rule9RangeMagnify") {}
private:
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
};
} // namespace nnet

46
include/nnet/ReplaceKit.h Normal file
View File

@ -0,0 +1,46 @@
#pragma once
#include "nnet/expr.h"
namespace nnet {
struct Replace {
int iteratorType;
vector<Var> oldIters; // i_1, ...
vector<Var> newIters; // j_1, ...
VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter
VecExpr psis; // i_1=\psi_1(j_1, ...)
vector<VarRangePair> newVarRanges;
bool isReplaced(Var var) const {
for (const auto &iter : oldIters)
if (iter->equal(var))
return true;
return false;
}
string toReadable() const {
string ret = "Old iters: " + serializeVec(oldIters) +
", new iters: " + serializeVec(newIters);
ret += " phis: " + serializeVec(phis) + " psis: " + serializeVec(psis);
return ret;
}
};
class ReplaceKit {
public:
static RangeOp replaceRangeOpIterator(const RangeOp &rangeOp,
const Replace &replace,
const Expr &replacedSummand);
static Subscript buildSubscirptForLoopVarReplace(const RangeOp &inner,
const Replace &replace);
static RangeOp buildDLTOuterRangeOp(const RangeOp &original,
const Subscript &subscriptedNewRangeOp);
static Expr replaceMultipleExprs(const Expr &cur,
const vector<Var> &patterns,
const VecExpr &replacements,
bool simplify = true);
static Expr replaceExpr(const Expr &cur, const Expr &pattern,
const Expr &replacement);
};
} // namespace nnet

View File

@ -0,0 +1,38 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
class AsTVMVisitor : public Functor<std::string(void)> {
private:
int nStage = 0, curStage = -1;
std::unordered_map<std::string, int> offset;
std::vector<std::string> inputs;
std::string output;
std::vector<std::string> pythonVars;
std::vector<std::vector<int>> inputShapes;
std::vector<int> outputShape;
std::string stmts;
public:
std::string getStmts() const;
const std::vector<std::string> &getInputs() const { return inputs; }
const std::string &getOutput() const { return output; }
const std::vector<std::vector<int>> &getInputShapes() const {
return inputShapes;
}
const std::vector<int> &getOutputShape() const { return outputShape; }
std::string visit_(const Constant &c) override;
std::string visit_(const BinaryOp &c) override;
std::string visit_(const Func &c) override;
std::string visit_(const RangeOp &c) override;
std::string visit_(const Subscript &c) override;
std::string visit_(const Var &c) override;
std::string visit_(const Tensor &c) override;
};
} // namespace nnet

View File

@ -0,0 +1,22 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class CheckOOBVisitor : public ExprTreeVisitor {
RangeOp rangeOp;
bool detect = false;
public:
CheckOOBVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 0, 0, _verobse) {}
void visit_(const Subscript &c) override;
/**
* @brief
* @return true If there is OOB
* @return false If there is no OOB
*/
bool checkRangeOp(const RangeOp &_rangeOp);
};
} // namespace nnet

View File

@ -0,0 +1,16 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Clone ExprNodes in a stage except Tensor, Var, and Constant nodes.
class CloneMutator : public Mutator {
public:
CloneMutator() : Mutator(false) {}
Expr visit_(const Constant &c) override;
Expr visit_(const Var &c) override;
Expr visit_(const Tensor &c) override;
Expr clone(const Expr &c) { return dispatch(c); }
};
} // namespace nnet

View File

@ -0,0 +1,15 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class CompareMultiFormulasVisitor : public ExprTreeVisitor {
vector<VarRangePair> newSumVarRanges;
RangeOp newRangeOp;
public:
CompareMultiFormulasVisitor() : ExprTreeVisitor() {}
bool compare(const VecExpr &roots);
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class CountRoutineVisitor : public ExprTreeVisitor {
private:
vector<int> cnts;
public:
CountRoutineVisitor(int _verobse = 0)
: ExprTreeVisitor(1, 1, 1, 1, _verobse) {}
void visit_(const Tensor &c) override;
vector<int> count(const Expr &root);
bool match(const Expr &root, int nMatmul = 0, int nConv = 0,
int nElement = 0, int nSg2bmm = 0, int nLongformerGBMM = 0);
};
} // namespace nnet

View File

@ -0,0 +1,25 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class FullPrinterVisitor : public ExprTreeVisitor {
private:
vector<tuple<string, Routine, Tensor>> q;
public:
FullPrinterVisitor(int _verobse = 0)
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
void visit_(const Tensor &c) override;
string print(const Expr &root);
/**
* @brief Get all tensors & OPs in a reversed order
*
* @param root
* @return vector<<Output TensorName, RoutineNode, output tensor in NNet>>
*/
const vector<tuple<string, Routine, Tensor>> &traverse(const Expr &root);
};
} // namespace nnet

View File

@ -0,0 +1,22 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Get all tensors in the stage
class GetTensorsVisitor : public ExprTreeVisitor {
private:
unordered_map<string, Tensor> tensors;
void visit_(const Tensor &c) override;
public:
GetTensorsVisitor(int _verobse = 0)
: ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
auto get(const Expr &c) {
dispatch(c);
return tensors;
}
};
} // namespace nnet

View File

@ -0,0 +1,31 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Calculate hash for a normal form, starting at a RangeOp
class HashVisitor : public Functor<HashType(void)> {
inline const static HashType BKDR_SEED[] = {131, 313, 10007, 65599};
PtrUmap<Iterator, int> varHash;
int nLoopVars = 0;
PtrUmap<Iterator, int> name2id;
vector<int> rootId;
vector<bool> haveAlias;
int nVars = 0;
vector<HashType> power;
private:
HashType visit_(const Constant &c) override;
HashType visit_(const BinaryOp &c) override;
HashType visit_(const RangeOp &c) override;
HashType visit_(const Subscript &c) override;
HashType visit_(const Tensor &c) override;
HashType visit_(const Var &c) override;
public:
HashVisitor(int _verobse = 0) : Functor(_verobse) {}
HashType getHash(const Expr &c);
};
} // namespace nnet

View File

@ -0,0 +1,23 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class InputVisitor : public ExprTreeVisitor {
vector<Tensor> inputs;
public:
int nInputs = 0;
InputVisitor(int _verobse = 0) : ExprTreeVisitor(1, 1, 1, 0, _verobse) {}
void visit_(const Tensor &c) override;
/**
* @brief Get the all inputs in the netsed stages
*/
vector<Tensor> getInputs(const RangeOp &_rangeOp) {
dispatch(_rangeOp);
return inputs;
}
};
} // namespace nnet

View File

@ -0,0 +1,55 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class Interpreter : public Functor<int()> {
public:
using ttype = int; // Test data type
using rtype = int; // Return data type
using Position = vector<int>;
using Inputs = unordered_map<string, Ref<vector<ttype>>>;
using Iteration = PtrUmap<Var, int>;
private:
// cache the input value
Inputs inputs;
vector<Iteration> iterations;
vector<Position> positions;
rtype visit_(const Constant &c) override;
rtype visit_(const BinaryOp &c) override;
rtype visit_(const RangeOp &c) override;
rtype visit_(const Subscript &c) override;
rtype visit_(const Var &c) override;
rtype visit_(const Tensor &c) override;
// int visit_(const Func &c); // Future work
static Inputs genInputStartingFromZero(const RangeOp &range);
public:
Interpreter(Inputs _inputs, int _verbose = 0)
: Functor(_verbose), inputs(_inputs) {}
Interpreter(RangeOp range, int _verbose = 0);
/**
* @brief Calculate the output at specified poistions
*
* @param expr The expression to be calculated.
* @param poses Positions of output.
* @return vector<int> Value of output.
*/
vector<rtype> interpret(const Expr &expr, const vector<Position> &poses);
/**
* @brief Calculate the output at equally spaced positions
*
* @param expr The expression to be calculated.
* @param nPoses The number of calculated output positions.
* @return vector<int> Value of output.
*/
vector<rtype> interpretUniformSample(const RangeOp &range,
int nPoses = 100);
vector<rtype> interpretAllOutput(const RangeOp &range);
};
} // namespace nnet

View File

@ -0,0 +1,14 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class MatchReshapeVisitor : public Functor<bool(void)> {
private:
PtrMap<Iterator, int> _coefficient;
public:
bool visit_(const RangeOp &c) override;
};
} // namespace nnet

View File

@ -0,0 +1,60 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
class MatchTableVisitor
: public Functor<void(const Tensor &, int dim, optional<int> stride)> {
private:
// Var -> {(tensor, dim)}
Appearance appearance;
vector<Tensor> tensors;
vector<Subscript> subscripts;
StrideTable strideTable;
PtrMap<Iterator, vector<vector<int>>>
strideInDim; // [Iterator][tensorID][dim]=stride
// Intermediate variable
// product of a sub-exprtree: Stride has to be done in two DFS
SubexprSride subexprStride;
bool hasUnsupportedOp = false;
public:
MatchTableVisitor(int _verobse = 0) : Functor(_verobse) {}
void visit_(const BinaryOp &c, const Tensor &tensor, int dim,
optional<int> stride) override;
void visit_(const Subscript &c, const Tensor &tensor, int dim,
optional<int> stride) override;
void visit_(const Var &c, const Tensor &tensor, int dim,
optional<int> stride) override;
void visit_(const Constant &c, const Tensor &tensor, int dim,
optional<int> stride) override;
// void visit_(const Tensor &c, const Tensor &tensor) override;
[[nodiscard]] bool operator()(const RangeOp &e) {
hasUnsupportedOp = false;
// get the location and stride of each iterator
auto mulOp = as<BinaryOpNode>(e->getSummand());
// TODO [feature]: support complex index exprs
if (!mulOp || mulOp->getOpType() != OpType::Mul) {
nnet_unimplemented_continue();
return false;
}
StrideVisitor strideVisitor(0);
subexprStride = strideVisitor.getFormulaStride(e);
dispatch(mulOp->getLhs(), nullptr, 0, 0);
dispatch(mulOp->getRhs(), nullptr, 0, 0);
subscripts.emplace_back(as<SubscriptNode>(mulOp->getLhs()));
subscripts.emplace_back(as<SubscriptNode>(mulOp->getRhs()));
assert(tensors.size() == subscripts.size());
assert(tensors.size() < 5);
return !hasUnsupportedOp;
}
auto getResult() const {
return tuple(appearance, tensors, strideTable, subscripts);
}
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class MatmulTransposeMutator : public Mutator {
Derivator &derivator;
public:
MatmulTransposeMutator(Derivator &derivator)
: Mutator(1), derivator(derivator) {}
VecExpr transpose(const Tensor &tensor);
private:
Tensor transposeInput(const Tensor &tensor);
};
} // namespace nnet

View File

@ -0,0 +1,20 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class MergeMemboundMutator : public Mutator {
VecExpr kernels;
int curDepth; // from the last one to the first one
Expr visit_(const Tensor &c) override;
// FIXME: duplicate code
Expr rule4StageMerging(Expr &rCur, bool mergeStageWithCalc);
bool checkEmpty();
public:
MergeMemboundMutator(const VecExpr &kernels)
: Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
Expr merge(bool allowEmptyMembound = false);
};
} // namespace nnet

View File

@ -0,0 +1,43 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
/**
* @brief Since the output positions of operators always start from 0, we have
* to offset them if the the boundary expression of is not 0.
*/
class PatternMatcher : public Functor<void(void)> {
private:
Derivator &derivator;
bool hasNonZeroRange;
const RangeOp originalCur;
public:
PatternMatcher(Derivator &derivator, const RangeOp &cur);
/**
* @brief Get the Cur whose loop vars are all offset to [0, x). Since
* operator outputs start from 0, RangeOp has to be aligned.
*/
RangeOp getOffsetCur();
/**
* @brief Add outer RangeOp to map the original positions to the new
* positions staring from 0.
*
* @param exprs Tensors from matched exprs
*/
VecExpr applyWrapper(const VecExpr &exprs);
VecExpr matchWithPattern(const RangeOp &rangeOp, const Pattern &pattern);
private:
VecExpr matchKernel(const Pattern &pattern, const RangeOp &rangeOp,
IteratorTable &exprIT);
// get reverse tensor and iterator map ([pattern tensor/iter ID] ->
// real)
Expr matchKernelWithTensorMap(const Pattern &pattern,
const RangeOp &rangeOp,
IteratorTable &exprIT);
};
} // namespace nnet

View File

@ -0,0 +1,25 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class RangeMagnifyVisitor : public Mutator {
vector<VarRangePair> newSumVarRanges;
RangeOp newRangeOp;
public:
RangeMagnifyVisitor() : Mutator(0) {}
Expr visit_(const RangeOp &c) override;
Expr visit_(const Subscript &c) override;
/**
* @brief
*
* @param root
* @param _newSumVarRanges
* @return RangeOp nullptr if failed to magnify
*/
RangeOp magnify(const RangeOp &root,
const vector<VarRangePair> &_newSumVarRanges);
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
using RangeMap = PtrMap<Iterator, Range>;
class RangeRelaxFunctor : public Functor<RangeMap()> {
RangeOp rangeOp;
public:
RangeRelaxFunctor(RangeOp _rangeOp) : Functor(false), rangeOp(_rangeOp) {}
RangeMap visit_(const BinaryOp &c) override;
RangeMap visit_(const RangeOp &c) override;
RangeMap visit_(const Subscript &c) override;
RangeMap intersectRangeMaps(const RangeMap &a, const RangeMap &b);
};
} // namespace nnet

View File

@ -0,0 +1,20 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
// Replace node according to its address the summand
// Only subscript and tensor are supported now.
class ReplaceNodeMutator : public Mutator {
int nSubscripts = 0;
ExprNode *target;
Expr replacement;
public:
ReplaceNodeMutator() : Mutator(0) {}
Expr visit_(const Subscript &c) override;
Expr visit_(const Tensor &c) override;
Expr replace(const Expr &root, ExprNode *_target, const Expr &_replace);
};
} // namespace nnet

View File

@ -0,0 +1,33 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
class ReplaceVariable : public Mutator {
VecExpr patterns, replacements;
map<HashType, int> patternHash;
public:
ReplaceVariable(Expr _pattern, Expr _replacement) : Mutator(false) {
set({_pattern}, {_replacement});
}
ReplaceVariable(const map<string, pair<Expr, Expr>> &mapping)
: Mutator(false) {
VecExpr _patterns, _replacements;
for (const auto &[_, v] : mapping) {
_patterns.emplace_back(v.first);
_replacements.emplace_back(v.second);
}
set(_patterns, _replacements);
}
Expr visit_(const BinaryOp &c) override;
// NOT recur to the next stage
Expr visit_(const RangeOp &c) override;
Expr visit_(const Var &c) override;
private:
void set(VecExpr _pattern, VecExpr _replacement);
Expr match(const Expr &c);
};
} // namespace nnet

View File

@ -0,0 +1,52 @@
#pragma once
#include "nlohmann/json_fwd.hpp"
#include "nnet/visitor.h"
#include <memory>
namespace nnet {
class Serializer : public Functor<string()> {
using json = nlohmann::ordered_json;
private:
static constexpr int VERSION{1};
std::unique_ptr<json> jPtr;
json &j;
static int id;
string visit_(const Constant &c) override;
string visit_(const BinaryOp &c) override;
string visit_(const RangeOp &c) override;
string visit_(const Subscript &c) override;
string visit_(const Var &c) override;
string visit_(const Tensor &c) override;
string dispatchRoutine(const Routine &c);
Expr buildExprTree(string key);
Routine buildRoutine(string key);
public:
Serializer(int _verobse = 0);
virtual ~Serializer();
/**
* @brief Serialize the given expression to json file
*
* @param expr The expression to be serialized
* @param filePath The path of json file to be output
* @param msg Message of derivation
* @return bool Whether the serialization succeed
*/
bool serialize(const Expr &expr, const string &filePath,
const string &msg = "");
/**
* @brief Deserialize the given json file to expression
*
* @param filePath The path to file to be deserialized
* @return Expression deserialized from the given json file
*/
Expr deserialize(const string &filePath);
};
} // namespace nnet

View File

@ -0,0 +1,38 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
// Simplify a index expression tree
class SimplifyExprVisitor : public Functor<void(optional<int> stride)> {
private:
SubexprSride subexprStride;
int constant;
PtrMap<Iterator, int> strides; // [var]=strides
map<pair<Iterator, int>, int, RefValueLess<pair<Iterator, int>>> divStrides,
modStrides; // 3*(i%8): [<i,8>]=3
// For divde and modulo with expr as dividend: 3*((i+1)%8): [<i+1,8>]=3
map<pair<Expr, int>, int, RefAddrLess<pair<Expr, int>>> divExprStrides,
modExprStrides;
public:
SimplifyExprVisitor() : Functor(0) {}
void visit_(const BinaryOp &c, optional<int> stride) override;
void visit_(const Var &c, optional<int> stride) override;
void visit_(const Constant &c, optional<int> stride) override;
PtrMap<Iterator, int> getStrides(const Expr &expr);
// TODO [refactor]: move this to SimplifyFormulaMutator as a member func
// this class should be get coefficients in a expr
Expr simplify(const Expr &expr);
int getConstant(const Expr &expr);
pair<PtrMap<Iterator, int>, int> getStridesConstant(const Expr &expr);
optional<Range> getExprRange(const Expr &expr, const RangeOp &rangeOp);
PtrMap<Iterator, int> getStrides() { return strides; }
const auto &getDivStrides() { return divStrides; }
const auto &getModStrides() { return modStrides; }
};
} // namespace nnet

View File

@ -0,0 +1,18 @@
#pragma once
#include "nnet/Visitor/StrideVisitor.h"
#include "nnet/visitor.h"
namespace nnet {
// Simplify all indexes in subscripts in an expression tree
class SimplifyFormulaMutator : public Mutator {
int nSubscripts = 0;
public:
SimplifyFormulaMutator() : Mutator(0) {}
Expr visit_(const Subscript &c) override;
// Expr visit_(const BinaryOp &c) override;
Expr simplify(const Expr &expr);
};
} // namespace nnet

View File

@ -0,0 +1,38 @@
#pragma once
#include "nnet/visitor.h"
namespace nnet {
using SubexprSride = map<const ExprNode *, optional<int>>;
class StrideVisitor : public Functor<optional<int>(void)> {
private:
SubexprSride subexprStride;
public:
StrideVisitor(int _verobse = 0) : Functor(_verobse) {}
optional<int> visit_(const BinaryOp &c) override;
optional<int> visit_(const Subscript &c) override;
optional<int> visit_(const Var &c) override;
optional<int> visit_(const Constant &c) override;
// void visit_(const Tensor &c, const Tensor &tensor) override;
auto getFormulaStride(const RangeOp &e) {
subexprStride.clear();
// get the location and stride of each iterator
auto mulOp = as<BinaryOpNode>(e->getSummand());
// TODO [feature]: support complex index exprs
if (!mulOp || mulOp->getOpType() != OpType::Mul)
nnet_unimplemented_continue();
dispatch(mulOp->getLhs());
dispatch(mulOp->getRhs());
return subexprStride;
}
[[nodiscard]] auto getExprStride(const Expr &e) {
subexprStride.clear();
dispatch(e);
return subexprStride;
}
};
} // namespace nnet

77
include/nnet/common.h Normal file
View File

@ -0,0 +1,77 @@
#pragma once
#include "dbg.h"
#include <cassert>
#include <list>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace nnet {
using std::dynamic_pointer_cast;
using std::endl;
using std::list;
using std::make_pair;
using std::make_shared;
using std::make_tuple;
using std::map;
using std::max;
using std::min;
using std::nullopt;
using std::optional;
using std::pair;
using std::set;
using std::shared_ptr;
using std::string;
using std::tie;
using std::to_string;
using std::tuple;
using std::unique_ptr;
using std::unordered_map;
template <typename T> using uset = std::unordered_set<T>;
using std::vector;
using std::weak_ptr;
// Aliases
using dtype = float;
using HashType = int;
template <typename T> struct ptr_less {
bool operator()(const T &lhs, const T &rhs) const { return *lhs < *rhs; }
};
template <typename T> struct ptr_hash {
size_t operator()(const T &lhs) const {
return std::hash<decltype(*lhs)>()(*lhs);
}
};
template <typename T> struct ptr_equal {
bool operator()(const T &lhs, const T &rhs) const { return *lhs == *rhs; }
};
static inline HashType genhash(HashType a, HashType b) {
return (a * 10007 + b + 12345) % 1000000007;
}
static inline HashType genhash(string s) {
HashType ret = 0;
for (auto c : s)
ret = genhash(ret, c);
return ret;
}
#define nnet_unimplemented_halt() \
{ assert(!"Unimplemented"); }
#define nnet_unimplemented_continue() \
{ dbg("Unimplemented"); }
#define nnet_assert(expr, msg) assert(((void)(msg), (expr)))
std::string pointer_to_hex(void *i);
} // namespace nnet

893
include/nnet/dbg.h Normal file
View File

@ -0,0 +1,893 @@
/*****************************************************************************
dbg(...) macro
License (MIT):
Copyright (c) 2019 David Peter <mail@david-peter.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*****************************************************************************/
#ifndef DBG_MACRO_DBG_H
#define DBG_MACRO_DBG_H
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
#define DBG_MACRO_UNIX
#elif defined(_MSC_VER)
#define DBG_MACRO_WINDOWS
#endif
// #ifndef DBG_MACRO_NO_WARNING
// #pragma message("WARNING: the 'dbg.h' header is included in your code base")
// #endif // DBG_MACRO_NO_WARNING
#include <algorithm>
#include <chrono>
#include <ctime>
#include <iomanip>
#include <ios>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
#ifdef DBG_MACRO_UNIX
#include <unistd.h>
#endif
#if __cplusplus >= 201703L
#define DBG_MACRO_CXX_STANDARD 17
#elif __cplusplus >= 201402L
#define DBG_MACRO_CXX_STANDARD 14
#else
#define DBG_MACRO_CXX_STANDARD 11
#endif
#if DBG_MACRO_CXX_STANDARD >= 17
#include <optional>
#include <variant>
#endif
namespace dbg {
#ifdef DBG_MACRO_UNIX
inline bool isColorizedOutputEnabled() {
return isatty(fileno(stderr));
}
#else
inline bool isColorizedOutputEnabled() {
return true;
}
#endif
struct time {};
namespace pretty_function {
// Compiler-agnostic version of __PRETTY_FUNCTION__ and constants to
// extract the template argument in `type_name_impl`
#if defined(__clang__)
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
static constexpr size_t PREFIX_LENGTH =
sizeof("const char *dbg::type_name_impl() [T = ") - 1;
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
#elif defined(__GNUC__) && !defined(__clang__)
#define DBG_MACRO_PRETTY_FUNCTION __PRETTY_FUNCTION__
static constexpr size_t PREFIX_LENGTH =
sizeof("const char* dbg::type_name_impl() [with T = ") - 1;
static constexpr size_t SUFFIX_LENGTH = sizeof("]") - 1;
#elif defined(_MSC_VER)
#define DBG_MACRO_PRETTY_FUNCTION __FUNCSIG__
static constexpr size_t PREFIX_LENGTH =
sizeof("const char *__cdecl dbg::type_name_impl<") - 1;
static constexpr size_t SUFFIX_LENGTH = sizeof(">(void)") - 1;
#else
#error "This compiler is currently not supported by dbg_macro."
#endif
} // namespace pretty_function
// Formatting helpers
template <typename T>
struct print_formatted {
static_assert(std::is_integral<T>::value,
"Only integral types are supported.");
print_formatted(T value, int numeric_base)
: inner(value), base(numeric_base) {}
operator T() const { return inner; }
const char* prefix() const {
switch (base) {
case 8:
return "0o";
case 16:
return "0x";
case 2:
return "0b";
default:
return "";
}
}
T inner;
int base;
};
template <typename T>
print_formatted<T> hex(T value) {
return print_formatted<T>{value, 16};
}
template <typename T>
print_formatted<T> oct(T value) {
return print_formatted<T>{value, 8};
}
template <typename T>
print_formatted<T> bin(T value) {
return print_formatted<T>{value, 2};
}
// Implementation of 'type_name<T>()'
template <typename T>
const char* type_name_impl() {
return DBG_MACRO_PRETTY_FUNCTION;
}
template <typename T>
struct type_tag {};
template <int&... ExplicitArgumentBarrier, typename T>
std::string get_type_name(type_tag<T>) {
namespace pf = pretty_function;
std::string type = type_name_impl<T>();
return type.substr(pf::PREFIX_LENGTH,
type.size() - pf::PREFIX_LENGTH - pf::SUFFIX_LENGTH);
}
template <typename T>
std::string type_name() {
if (std::is_volatile<T>::value) {
if (std::is_pointer<T>::value) {
return type_name<typename std::remove_volatile<T>::type>() + " volatile";
} else {
return "volatile " + type_name<typename std::remove_volatile<T>::type>();
}
}
if (std::is_const<T>::value) {
if (std::is_pointer<T>::value) {
return type_name<typename std::remove_const<T>::type>() + " const";
} else {
return "const " + type_name<typename std::remove_const<T>::type>();
}
}
if (std::is_pointer<T>::value) {
return type_name<typename std::remove_pointer<T>::type>() + "*";
}
if (std::is_lvalue_reference<T>::value) {
return type_name<typename std::remove_reference<T>::type>() + "&";
}
if (std::is_rvalue_reference<T>::value) {
return type_name<typename std::remove_reference<T>::type>() + "&&";
}
return get_type_name(type_tag<T>{});
}
inline std::string get_type_name(type_tag<short>) {
return "short";
}
inline std::string get_type_name(type_tag<unsigned short>) {
return "unsigned short";
}
inline std::string get_type_name(type_tag<long>) {
return "long";
}
inline std::string get_type_name(type_tag<unsigned long>) {
return "unsigned long";
}
inline std::string get_type_name(type_tag<std::string>) {
return "std::string";
}
template <typename T>
std::string get_type_name(type_tag<std::vector<T, std::allocator<T>>>) {
return "std::vector<" + type_name<T>() + ">";
}
template <typename T1, typename T2>
std::string get_type_name(type_tag<std::pair<T1, T2>>) {
return "std::pair<" + type_name<T1>() + ", " + type_name<T2>() + ">";
}
template <typename... T>
std::string type_list_to_string() {
std::string result;
auto unused = {(result += type_name<T>() + ", ", 0)..., 0};
static_cast<void>(unused);
#if DBG_MACRO_CXX_STANDARD >= 17
if constexpr (sizeof...(T) > 0) {
#else
if (sizeof...(T) > 0) {
#endif
result.pop_back();
result.pop_back();
}
return result;
}
template <typename... T>
std::string get_type_name(type_tag<std::tuple<T...>>) {
return "std::tuple<" + type_list_to_string<T...>() + ">";
}
template <typename T>
inline std::string get_type_name(type_tag<print_formatted<T>>) {
return type_name<T>();
}
// Implementation of 'is_detected' to specialize for container-like types
namespace detail_detector {
struct nonesuch {
nonesuch() = delete;
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};
template <typename...>
using void_t = void;
template <class Default,
class AlwaysVoid,
template <class...>
class Op,
class... Args>
struct detector {
using value_t = std::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, void_t<Op<Args...>>, Op, Args...> {
using value_t = std::true_type;
using type = Op<Args...>;
};
} // namespace detail_detector
template <template <class...> class Op, class... Args>
using is_detected = typename detail_detector::
detector<detail_detector::nonesuch, void, Op, Args...>::value_t;
namespace detail {
namespace {
using std::begin;
using std::end;
#if DBG_MACRO_CXX_STANDARD < 17
template <typename T>
constexpr auto size(const T& c) -> decltype(c.size()) {
return c.size();
}
template <typename T, std::size_t N>
constexpr std::size_t size(const T (&)[N]) {
return N;
}
#else
using std::size;
#endif
} // namespace
template <typename T>
using detect_begin_t = decltype(detail::begin(std::declval<T>()));
template <typename T>
using detect_end_t = decltype(detail::end(std::declval<T>()));
template <typename T>
using detect_size_t = decltype(detail::size(std::declval<T>()));
template <typename T>
struct is_container {
static constexpr bool value =
is_detected<detect_begin_t, T>::value &&
is_detected<detect_end_t, T>::value &&
is_detected<detect_size_t, T>::value &&
!std::is_same<std::string,
typename std::remove_cv<
typename std::remove_reference<T>::type>::type>::value;
};
template <typename T>
using ostream_operator_t =
decltype(std::declval<std::ostream&>() << std::declval<T>());
template <typename T>
struct has_ostream_operator : is_detected<ostream_operator_t, T> {};
} // namespace detail
// Helper to dbg(…)-print types
template <typename T>
struct print_type {};
template <typename T>
print_type<T> type() {
return print_type<T>{};
}
// Forward declarations of "pretty_print"
template <typename T>
inline void pretty_print(std::ostream& stream, const T& value, std::true_type);
template <typename T>
inline void pretty_print(std::ostream&, const T&, std::false_type);
template <typename T>
inline typename std::enable_if<!detail::is_container<const T&>::value &&
!std::is_enum<T>::value,
bool>::type
pretty_print(std::ostream& stream, const T& value);
inline bool pretty_print(std::ostream& stream, const bool& value);
inline bool pretty_print(std::ostream& stream, const char& value);
template <typename P>
inline bool pretty_print(std::ostream& stream, P* const& value);
template <typename T, typename Deleter>
inline bool pretty_print(std::ostream& stream,
std::unique_ptr<T, Deleter>& value);
// template <typename T>
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value);
template <size_t N>
inline bool pretty_print(std::ostream& stream, const char (&value)[N]);
template <>
inline bool pretty_print(std::ostream& stream, const char* const& value);
template <typename... Ts>
inline bool pretty_print(std::ostream& stream, const std::tuple<Ts...>& value);
template <>
inline bool pretty_print(std::ostream& stream, const std::tuple<>&);
template <>
inline bool pretty_print(std::ostream& stream, const time&);
template <typename T>
inline bool pretty_print(std::ostream& stream,
const print_formatted<T>& value);
template <typename T>
inline bool pretty_print(std::ostream& stream, const print_type<T>&);
template <typename Enum>
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
pretty_print(std::ostream& stream, Enum const& value);
inline bool pretty_print(std::ostream& stream, const std::string& value);
#if DBG_MACRO_CXX_STANDARD >= 17
inline bool pretty_print(std::ostream& stream, const std::string_view& value);
#endif
template <typename T1, typename T2>
inline bool pretty_print(std::ostream& stream, const std::pair<T1, T2>& value);
#if DBG_MACRO_CXX_STANDARD >= 17
template <typename T>
inline bool pretty_print(std::ostream& stream, const std::optional<T>& value);
template <typename... Ts>
inline bool pretty_print(std::ostream& stream,
const std::variant<Ts...>& value);
#endif
template <typename Container>
inline typename std::enable_if<detail::is_container<const Container&>::value,
bool>::type
pretty_print(std::ostream& stream, const Container& value);
// Specializations of "pretty_print"
template <typename T>
inline void pretty_print(std::ostream& stream, const T& value, std::true_type) {
stream << value;
}
template <typename T>
inline void pretty_print(std::ostream&, const T&, std::false_type) {
static_assert(detail::has_ostream_operator<const T&>::value,
"Type does not support the << ostream operator");
}
template <typename T>
inline typename std::enable_if<!detail::is_container<const T&>::value &&
!std::is_enum<T>::value,
bool>::type
pretty_print(std::ostream& stream, const T& value) {
pretty_print(stream, value,
typename detail::has_ostream_operator<const T&>::type{});
return true;
}
inline bool pretty_print(std::ostream& stream, const bool& value) {
stream << std::boolalpha << value;
return true;
}
inline bool pretty_print(std::ostream& stream, const char& value) {
const bool printable = value >= 0x20 && value <= 0x7E;
if (printable) {
stream << "'" << value << "'";
} else {
stream << "'\\x" << std::setw(2) << std::setfill('0') << std::hex
<< std::uppercase << (0xFF & value) << "'";
}
return true;
}
template <typename P>
inline bool pretty_print(std::ostream& stream, P* const& value) {
if (value == nullptr) {
stream << "nullptr";
} else {
stream << value;
}
return true;
}
template <typename T, typename Deleter>
inline bool pretty_print(std::ostream& stream,
std::unique_ptr<T, Deleter>& value) {
pretty_print(stream, value.get());
return true;
}
// template <typename T>
// inline bool pretty_print(std::ostream& stream, std::shared_ptr<T>& value) {
// pretty_print(stream, value.get());
// stream << " (use_count = " << value.use_count() << ")";
// return true;
// }
template <size_t N>
inline bool pretty_print(std::ostream& stream, const char (&value)[N]) {
stream << value;
return false;
}
template <>
inline bool pretty_print(std::ostream& stream, const char* const& value) {
stream << '"' << value << '"';
return true;
}
template <size_t Idx>
struct pretty_print_tuple {
template <typename... Ts>
static void print(std::ostream& stream, const std::tuple<Ts...>& tuple) {
pretty_print_tuple<Idx - 1>::print(stream, tuple);
stream << ", ";
pretty_print(stream, std::get<Idx>(tuple));
}
};
template <>
struct pretty_print_tuple<0> {
template <typename... Ts>
static void print(std::ostream& stream, const std::tuple<Ts...>& tuple) {
pretty_print(stream, std::get<0>(tuple));
}
};
template <typename... Ts>
inline bool pretty_print(std::ostream& stream, const std::tuple<Ts...>& value) {
stream << "{";
pretty_print_tuple<sizeof...(Ts) - 1>::print(stream, value);
stream << "}";
return true;
}
template <>
inline bool pretty_print(std::ostream& stream, const std::tuple<>&) {
stream << "{}";
return true;
}
template <>
inline bool pretty_print(std::ostream& stream, const time&) {
using namespace std::chrono;
const auto now = system_clock::now();
const auto us =
duration_cast<microseconds>(now.time_since_epoch()).count() % 1000000;
const auto hms = system_clock::to_time_t(now);
const std::tm* tm = std::localtime(&hms);
stream << "current time = " << std::put_time(tm, "%H:%M:%S") << '.'
<< std::setw(6) << std::setfill('0') << us;
return false;
}
// Converts decimal integer to binary string
template <typename T>
std::string decimalToBinary(T n) {
const size_t length = 8 * sizeof(T);
std::string toRet;
toRet.resize(length);
for (size_t i = 0; i < length; ++i) {
const auto bit_at_index_i = static_cast<char>((n >> i) & 1);
toRet[length - 1 - i] = bit_at_index_i + '0';
}
return toRet;
}
template <typename T>
inline bool pretty_print(std::ostream& stream,
const print_formatted<T>& value) {
if (value.inner < 0) {
stream << "-";
}
stream << value.prefix();
// Print using setbase
if (value.base != 2) {
stream << std::setw(sizeof(T)) << std::setfill('0')
<< std::setbase(value.base) << std::uppercase;
if (value.inner >= 0) {
// The '+' sign makes sure that a uint_8 is printed as a number
stream << +value.inner;
} else {
using unsigned_type = typename std::make_unsigned<T>::type;
stream << +(static_cast<unsigned_type>(-(value.inner + 1)) + 1);
}
} else {
// Print for binary
if (value.inner >= 0) {
stream << decimalToBinary(value.inner);
} else {
using unsigned_type = typename std::make_unsigned<T>::type;
stream << decimalToBinary<unsigned_type>(
static_cast<unsigned_type>(-(value.inner + 1)) + 1);
}
}
return true;
}
template <typename T>
inline bool pretty_print(std::ostream& stream, const print_type<T>&) {
stream << type_name<T>();
stream << " [sizeof: " << sizeof(T) << " byte, ";
stream << "trivial: ";
if (std::is_trivial<T>::value) {
stream << "yes";
} else {
stream << "no";
}
stream << ", standard layout: ";
if (std::is_standard_layout<T>::value) {
stream << "yes";
} else {
stream << "no";
}
stream << "]";
return false;
}
template <typename Enum>
inline typename std::enable_if<std::is_enum<Enum>::value, bool>::type
pretty_print(std::ostream& stream, Enum const& value) {
using UnderlyingType = typename std::underlying_type<Enum>::type;
stream << static_cast<UnderlyingType>(value);
return true;
}
inline bool pretty_print(std::ostream& stream, const std::string& value) {
stream << '"' << value << '"';
return true;
}
#if DBG_MACRO_CXX_STANDARD >= 17
inline bool pretty_print(std::ostream& stream, const std::string_view& value) {
stream << '"' << std::string(value) << '"';
return true;
}
#endif
template <typename T1, typename T2>
inline bool pretty_print(std::ostream& stream, const std::pair<T1, T2>& value) {
stream << "{";
pretty_print(stream, value.first);
stream << ", ";
pretty_print(stream, value.second);
stream << "}";
return true;
}
#if DBG_MACRO_CXX_STANDARD >= 17
template <typename T>
inline bool pretty_print(std::ostream& stream, const std::optional<T>& value) {
if (value) {
stream << '{';
pretty_print(stream, *value);
stream << '}';
} else {
stream << "nullopt";
}
return true;
}
template <typename... Ts>
inline bool pretty_print(std::ostream& stream,
const std::variant<Ts...>& value) {
stream << "{";
std::visit([&stream](auto&& arg) { pretty_print(stream, arg); }, value);
stream << "}";
return true;
}
#endif
template <typename Container>
inline typename std::enable_if<detail::is_container<const Container&>::value,
bool>::type
pretty_print(std::ostream& stream, const Container& value) {
stream << "{";
const size_t size = detail::size(value);
const size_t n = std::min(size_t{10}, size);
size_t i = 0;
using std::begin;
using std::end;
for (auto it = begin(value); it != end(value) && i < n; ++it, ++i) {
pretty_print(stream, *it);
if (i != n - 1) {
stream << ", ";
}
}
if (size > n) {
stream << ", ...";
stream << " size:" << size;
}
stream << "}";
return true;
}
template <typename T, typename... U>
struct last {
using type = typename last<U...>::type;
};
template <typename T>
struct last<T> {
using type = T;
};
template <typename... T>
using last_t = typename last<T...>::type;
class DebugOutput {
public:
// Helper alias to avoid obscure type `const char* const*` in signature.
using expr_t = const char*;
DebugOutput(const char* filepath, int line, const char* function_name)
: m_use_colorized_output(isColorizedOutputEnabled()) {
std::string path = filepath;
const std::size_t path_length = path.length();
if (path_length > MAX_PATH_LENGTH) {
path = ".." + path.substr(path_length - MAX_PATH_LENGTH, MAX_PATH_LENGTH);
}
std::stringstream ss;
ss << ansi(ANSI_DEBUG) << "[" << path << ":" << line << " ("
<< function_name << ")] " << ansi(ANSI_RESET);
m_location = ss.str();
}
template <typename... T>
auto print(std::initializer_list<expr_t> exprs,
std::initializer_list<std::string> types,
T&&... values) -> last_t<T...> {
if (exprs.size() != sizeof...(values)) {
std::cerr
<< m_location << ansi(ANSI_WARN)
<< "The number of arguments mismatch, please check unprotected comma"
<< ansi(ANSI_RESET) << std::endl;
}
return print_impl(exprs.begin(), types.begin(), std::forward<T>(values)...);
}
private:
template <typename T>
T&& print_impl(const expr_t* expr, const std::string* type, T&& value) {
const T& ref = value;
std::stringstream stream_value;
const bool print_expr_and_type = pretty_print(stream_value, ref);
std::stringstream output;
output << m_location;
if (print_expr_and_type) {
output << ansi(ANSI_EXPRESSION) << *expr << ansi(ANSI_RESET) << " = ";
}
output << ansi(ANSI_VALUE) << stream_value.str() << ansi(ANSI_RESET);
if (print_expr_and_type) {
output << " (" << ansi(ANSI_TYPE) << *type << ansi(ANSI_RESET) << ")";
}
output << std::endl;
std::cerr << output.str();
return std::forward<T>(value);
}
template <typename T, typename... U>
auto print_impl(const expr_t* exprs,
const std::string* types,
T&& value,
U&&... rest) -> last_t<T, U...> {
print_impl(exprs, types, std::forward<T>(value));
return print_impl(exprs + 1, types + 1, std::forward<U>(rest)...);
}
const char* ansi(const char* code) const {
if (m_use_colorized_output) {
return code;
} else {
return ANSI_EMPTY;
}
}
const bool m_use_colorized_output;
std::string m_location;
static constexpr std::size_t MAX_PATH_LENGTH = 20;
static constexpr const char* const ANSI_EMPTY = "";
static constexpr const char* const ANSI_DEBUG = "\x1b[02m";
static constexpr const char* const ANSI_WARN = "\x1b[33m";
static constexpr const char* const ANSI_EXPRESSION = "\x1b[36m";
static constexpr const char* const ANSI_VALUE = "\x1b[01m";
static constexpr const char* const ANSI_TYPE = "\x1b[32m";
static constexpr const char* const ANSI_RESET = "\x1b[0m";
};
// Identity function to suppress "-Wunused-value" warnings in DBG_MACRO_DISABLE
// mode
template <typename T>
T&& identity(T&& t) {
return std::forward<T>(t);
}
template <typename T, typename... U>
auto identity(T&&, U&&... u) -> last_t<U...> {
return identity(std::forward<U>(u)...);
}
} // namespace dbg
#ifndef DBG_MACRO_DISABLE
// Force expanding argument with commas for MSVC, ref:
// https://stackoverflow.com/questions/35210637/macro-expansion-argument-with-commas
// Note that "args" should be a tuple with parentheses, such as "(e1, e2, ...)".
#define DBG_IDENTITY(x) x
#define DBG_CALL(fn, args) DBG_IDENTITY(fn args)
#define DBG_CAT_IMPL(_1, _2) _1##_2
#define DBG_CAT(_1, _2) DBG_CAT_IMPL(_1, _2)
#define DBG_16TH_IMPL(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
_14, _15, _16, ...) \
_16
#define DBG_16TH(args) DBG_CALL(DBG_16TH_IMPL, args)
#define DBG_NARG(...) \
DBG_16TH((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
// DBG_VARIADIC_CALL(fn, data, e1, e2, ...) => fn_N(data, (e1, e2, ...))
#define DBG_VARIADIC_CALL(fn, data, ...) \
DBG_CAT(fn##_, DBG_NARG(__VA_ARGS__))(data, (__VA_ARGS__))
// (e1, e2, e3, ...) => e1
#define DBG_HEAD_IMPL(_1, ...) _1
#define DBG_HEAD(args) DBG_CALL(DBG_HEAD_IMPL, args)
// (e1, e2, e3, ...) => (e2, e3, ...)
#define DBG_TAIL_IMPL(_1, ...) (__VA_ARGS__)
#define DBG_TAIL(args) DBG_CALL(DBG_TAIL_IMPL, args)
#define DBG_MAP_1(fn, args) DBG_CALL(fn, args)
#define DBG_MAP_2(fn, args) fn(DBG_HEAD(args)), DBG_MAP_1(fn, DBG_TAIL(args))
#define DBG_MAP_3(fn, args) fn(DBG_HEAD(args)), DBG_MAP_2(fn, DBG_TAIL(args))
#define DBG_MAP_4(fn, args) fn(DBG_HEAD(args)), DBG_MAP_3(fn, DBG_TAIL(args))
#define DBG_MAP_5(fn, args) fn(DBG_HEAD(args)), DBG_MAP_4(fn, DBG_TAIL(args))
#define DBG_MAP_6(fn, args) fn(DBG_HEAD(args)), DBG_MAP_5(fn, DBG_TAIL(args))
#define DBG_MAP_7(fn, args) fn(DBG_HEAD(args)), DBG_MAP_6(fn, DBG_TAIL(args))
#define DBG_MAP_8(fn, args) fn(DBG_HEAD(args)), DBG_MAP_7(fn, DBG_TAIL(args))
#define DBG_MAP_9(fn, args) fn(DBG_HEAD(args)), DBG_MAP_8(fn, DBG_TAIL(args))
#define DBG_MAP_10(fn, args) fn(DBG_HEAD(args)), DBG_MAP_9(fn, DBG_TAIL(args))
#define DBG_MAP_11(fn, args) fn(DBG_HEAD(args)), DBG_MAP_10(fn, DBG_TAIL(args))
#define DBG_MAP_12(fn, args) fn(DBG_HEAD(args)), DBG_MAP_11(fn, DBG_TAIL(args))
#define DBG_MAP_13(fn, args) fn(DBG_HEAD(args)), DBG_MAP_12(fn, DBG_TAIL(args))
#define DBG_MAP_14(fn, args) fn(DBG_HEAD(args)), DBG_MAP_13(fn, DBG_TAIL(args))
#define DBG_MAP_15(fn, args) fn(DBG_HEAD(args)), DBG_MAP_14(fn, DBG_TAIL(args))
#define DBG_MAP_16(fn, args) fn(DBG_HEAD(args)), DBG_MAP_15(fn, DBG_TAIL(args))
// DBG_MAP(fn, e1, e2, e3, ...) => fn(e1), fn(e2), fn(e3), ...
#define DBG_MAP(fn, ...) DBG_VARIADIC_CALL(DBG_MAP, fn, __VA_ARGS__)
#define DBG_STRINGIFY_IMPL(x) #x
#define DBG_STRINGIFY(x) DBG_STRINGIFY_IMPL(x)
#define DBG_TYPE_NAME(x) dbg::type_name<decltype(x)>()
#define dbg(...) \
dbg::DebugOutput(__FILE__, __LINE__, __func__) \
.print({DBG_MAP(DBG_STRINGIFY, __VA_ARGS__)}, \
{DBG_MAP(DBG_TYPE_NAME, __VA_ARGS__)}, __VA_ARGS__)
#else
#define dbg(...) dbg::identity(__VA_ARGS__)
#endif // DBG_MACRO_DISABLE
#endif // DBG_MACRO_DBG_H

156
include/nnet/derivator.h Normal file
View File

@ -0,0 +1,156 @@
#pragma once
#include "common.h"
#include "expr.h"
#include "iterator_table.h"
#include "routine.h"
#include <iostream>
#include <sstream>
#include <unordered_set>
namespace nnet {
class Formula {
public:
Expr root;
const int bfsDepth;
public:
Formula(Expr _root, int _bfsDepth) : root(_root), bfsDepth(_bfsDepth) {}
string toReadable() const;
friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
ios << expr.toReadable();
return ios;
}
bool isVariable() const { return as<VarNode>(root) != nullptr; }
};
class MultiFormulas {
public:
VecExpr roots;
const int bfsDepth;
public:
MultiFormulas(VecExpr roots, int _bfsDepth)
: roots(roots), bfsDepth(_bfsDepth) {}
// string toReadable() const;
// friend std::ostream &operator<<(std::ostream &ios, const Formula &expr) {
// ios << expr.toReadable();
// return ios;
// }
};
class Derivator {
public:
enum class LogMode { Normal, DumpFristCandiate, NoLog };
enum class PassMode { Debug, Full };
private:
list<Formula> candidates;
const int maxDepth;
int nIteratorNames = 0;
int nTensorNames = 0;
vector<vector<int>> rulesOverall;
enum class Strategy { DFS, Rule, RuleAndDFS } searchStrategy;
LogMode logMode;
PassMode passMode;
bool enableEquivalenceCheck = false;
string logFnPrefix;
const bool enableHashPruning;
int searchedMaxDepth = 0;
RoutineType targetOp = RoutineType::NoneType;
map<int, vector<Var>> substituteRules;
vector<int> cntAppliedRules;
int cntRule3 = 0;
std::unordered_set<HashType> visited;
VecExpr intermediateStates;
vector<string> ruleStates, ruleMsgs;
int cntStates = 0; // the number of intermediate states
int searchState = 0; // search state in guided search
public:
Derivator(int maxDepth = 8, bool enableHashPruning = true,
LogMode mode = LogMode::NoLog,
PassMode passMode = PassMode::Debug);
void search(Formula &origin, int depth);
void ruleBasedDFS(Formula &origin, int depth, vector<int> _rules,
map<int, vector<Var>> _substituteRules = {},
bool searchAfterRules = false);
void guidedSearch(Formula &origin, int depth);
void print();
int getNumCandidates() const { return candidates.size(); }
const auto &getCandidates() const { return candidates; }
void appendCanddiate(const Tensor &tensor, int depth);
int getSearchedMaxDepth() const { return searchedMaxDepth; };
bool stageCombination(MultiFormulas &origin, int depth);
bool checkOOB(const RangeOp &rangeOp, bool halt = true);
string newTensorName();
Var getNewVar();
Expr mergeMemboundStages(VecExpr stages);
private:
void dfs(Formula &origin, int depth);
void ruleBasedDerivate(Formula &origin, int depth);
void rule1VariableSplit(Formula &origin, int depth, Expr &rCur);
void rule2VariableMerging(Formula &origin, int depth, Expr &rCur);
void rule3StageSplit(Formula &origin, int dfsDepth, Expr &rCur);
void rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur);
bool rule4StageMerging(Formula &origin, int depth, Expr &rCur,
bool mergeStageWithCalc = false);
void rule6KenerlMatching(Formula &origin, int depth, Expr &rCur);
void rule7DLT(Formula &origin, int depth, Expr &rCur);
// Rule 8: guidedDLT
void rule8GuidedDLT(Formula &origin, int depth, Expr &rCur);
void rule9RangeMagnify(Formula &origin, int depth, Expr &rCur);
void rule90TwoStageElementWise(Formula &origin, int depth, Expr &rCur);
void rule91MergeStagesWithSum(Formula &origin, int depth, Expr &rCur);
/**
* @brief For searchState=2, wrap the RangeOp to add offset, if the boundary
* does not start from 0. Then match the inner offset RangeOp.
*/
void matchComputationKernel(Formula &origin, int depth, Expr &rcur);
/**
* @brief For searchState=3, the Formula must be a MemBound kernel?
*/
void matchMemBoundKernel(Formula &origin, int depth, Expr &rcur);
/**
* @brief Check the equivalence for exprs in intermediateStates.
*/
void checkDerivationEquivalence();
public:
void pushIntermediateState(const Expr &expr);
void pushRuleState(const string &state);
void pushRuleMsg(const string &state);
void popIntermediateState();
void popRuleState();
void popRuleMsg();
// void pushTransformInfo(const Expr &expr, const string &state,
// const string &msg);
void nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur);
RoutineType getTargetOp();
void setTargetOp(RoutineType _targetOp);
int getSearchState();
void setSearchState(int _searchState);
int getNumIntermediateStates();
void printStatistics();
void printIntermediateStates();
/**
* @brief Enable dumping the first results. Verification is enabled so it is
* slow.
*
* @param _logFnPrefix Prefix of output filename
*/
void setDumpFirstSuccess(const string &_logFnPrefix);
void setEquivalenceCheck();
PassMode getPassMode();
LogMode getLogMode();
};
} // namespace nnet

53
include/nnet/dlt.h Normal file
View File

@ -0,0 +1,53 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
namespace nnet {
// enum class DLTType { Split, Merge, Reorder };
struct DLTOperation {
// DLTType type;
virtual ~DLTOperation() {}
};
struct DLTSplit : DLTOperation {
int dim, factor;
DLTSplit(int _dim, int _factor) : dim(_dim), factor(_factor) {}
};
struct DLTMerge : DLTOperation {
int dim0, dim1;
DLTMerge(int _dim0, int _dim1) : dim0(_dim0), dim1(_dim1) {}
};
struct DLTReorder : DLTOperation {
vector<int> dims;
DLTReorder(vector<int> _dims) : dims(_dims) {}
};
class DLT {
vector<Ref<DLTOperation>> ops;
public:
/**
* @brief dim -> (dim/factor, factor)
*/
void split(int dim, int factor);
/**
* @brief Merge dim1 into dim0 -> (dim0, dim1)
*/
void merge(int dim0, int dim1);
/**
* @brief
*
* @param dims dims[new_dim]=old_dim
*/
void reorder(vector<int> dims);
optional<Expr> apply(const RangeOp &rangeOp, const Subscript &subscript,
string newTensorName);
private:
optional<pair<Expr, Expr>> splitIndex(Expr expr, int factor,
RangeOp rangeOp);
};
} // namespace nnet

416
include/nnet/expr.h Normal file
View File

@ -0,0 +1,416 @@
#pragma once
#include "common.h"
#include "ref.h"
#include <iostream>
#include <numeric>
#include <type_traits>
namespace nnet {
class ExprNode;
class VarNode;
class TensorNode;
class OperatorNode;
class RangeOpNode;
class SubscriptNode;
class BinaryOpNode;
class ConstantNode;
class FuncNode;
using Expr = Ref<ExprNode>;
using Var = Ref<VarNode>;
using Tensor = Ref<TensorNode>;
using Operator = Ref<OperatorNode>;
using RangeOp = Ref<RangeOpNode>;
using Subscript = Ref<SubscriptNode>;
using BinaryOp = Ref<BinaryOpNode>;
using Constant = Ref<ConstantNode>;
using Func = Ref<FuncNode>;
class RoutineNode;
using Routine = Ref<RoutineNode>;
enum class RoutineType {
NoneType = 100,
MatmulNodeType,
ConvNodeType,
G2bmmNodeType,
GbmmNodeType,
ElementWiseNodeType // unmatchable
};
constexpr inline int MatchableRoutineTypeCnt = 4;
constexpr inline int RoutineTypeCnt = MatchableRoutineTypeCnt + 1;
inline RoutineType idToRoutineType(int i) {
return static_cast<RoutineType>(i + 1 +
static_cast<int>(RoutineType::NoneType));
}
inline int routineTypeToId(const RoutineType &routineType) {
return static_cast<int>(routineType) -
static_cast<int>(RoutineType::NoneType) - 1;
}
using VecExpr = vector<Expr>;
// common data structure
using Iterator = Var; // RE: remove this alias
template <typename T, typename U> using PtrMap = std::map<T, U, ptr_less<T>>;
template <typename T, typename U>
// When keys are pointers, compare keys according to its value instead of
// address Specially, the name of Var are compared due to the overload of op=
// and hash.
using PtrUmap = std::unordered_map<T, U, ptr_hash<T>, ptr_equal<T>>;
template <typename T>
using PtrUset = std::unordered_set<T, ptr_hash<T>, ptr_equal<T>>;
using Appearance = PtrMap<Var, vector<pair<Tensor, int>>>;
using StrideTable =
PtrMap<Var, vector<tuple<TensorNode *, int, int>>>; // Tensor, dim, stride
// AST node opeartor
bool operator==(const Var &lhs, const string &rhs);
bool operator==(const string &lhs, const Var &rhs);
Expr operator+(const Expr &lhs, const Expr &rhs);
BinaryOp operator-(const Expr &lhs, const Expr &rhs);
BinaryOp operator*(const Expr &lhs, const Expr &rhs);
BinaryOp operator/(const Expr &lhs, const Expr &rhs);
BinaryOp operator%(const Expr &lhs, const Expr &rhs);
Expr operator+(const Expr &lhs, const int &rhs);
Expr operator+(const int &lhs, const Expr &rhs);
Expr operator-(const Expr &lhs, const int &rhs);
Expr operator-(const int &lhs, const Expr &rhs);
Expr operator*(const Expr &lhs, const int &rhs);
Expr operator*(const int &lhs, const Expr &rhs);
Expr operator%(const Expr &lhs, const int rhs);
Expr operator/(const Expr &lhs, const int rhs);
string serializeVec(vector<Expr> v);
string serializeVec(vector<Var> v);
template <typename T> inline string serializeVec(vector<T> v) {
if (v.empty())
return "[]";
return "[" +
std::accumulate(
v.begin() + 1, v.end(), to_string(v[0]),
[](const string &a, int b) { return a + ',' + to_string(b); }) +
"]";
}
// For RTTI and visitor pattern
enum class NodeType {
ConstantNodeType,
BinaryOpNodeType,
RangeOpNodeType,
SubscriptNodeType,
TensorNodeType,
VarNodeType,
FuncNodeType
};
enum class FuncType { Relu, Tanh };
#define DEFINE_GETTYPE(CLASS) \
NodeType getType() const override { return NodeType::CLASS##Type; }
class ExprNode {
public:
virtual ~ExprNode() {}
ExprNode &operator=(const ExprNode &rhs) = delete;
virtual HashType hash() const = 0; // RE: remove?
virtual string toReadable() const = 0;
friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr);
virtual NodeType getType() const = 0;
};
class VarNode : public ExprNode {
std::string name;
public:
VarNode(std::string _name) : name(_name){};
virtual ~VarNode() {}
DEFINE_GETTYPE(VarNode);
const std::string &getName() const { return name; }
HashType hash() const override { return genhash(name); };
string toReadable() const override { return name; };
bool equal(const Var &rhs) const { return name == rhs->getName(); }
bool neq(const Var &rhs) const { return !equal(rhs); }
bool less(const Var &rhs) const { return name < rhs->getName(); }
bool equal(const string &rhs) const { return name == rhs; }
bool operator==(const VarNode &rhs) const { return name == rhs.getName(); }
bool operator<(const VarNode &rhs) const { return name < rhs.getName(); }
};
enum class TensorType { Input, Weight, Intermediate };
class TensorNode : public ExprNode {
string name;
vector<int> shape, paddings;
TensorType type;
Routine source; // if NO source, then this is a input/weight tensor
public:
TensorNode(string _name, vector<int> _shape, vector<int> _paddings = {},
Routine _source = nullptr);
virtual ~TensorNode() {}
DEFINE_GETTYPE(TensorNode);
bool operator==(const string &rhs) { return name == rhs; }
friend bool operator==(const string &lhs, const TensorNode &rhs) {
return lhs == rhs.name;
}
HashType hash() const override { return genhash(name); }
string toReadable() const override;
string toOutputShape() const;
const std::string &getName() const { return name; }
std::vector<int> &getPadding() { return paddings; }
int getPadding(int i) const { return paddings[i]; }
const vector<int> &getPaddings() const { return paddings; }
void setPadding(int i, int p) { paddings[i] = p; }
const vector<int> &getShape() const { return shape; }
int getShape(int i) const { return shape[i]; }
int64_t getSize() const;
int getDims() const { return shape.size(); }
const Routine &getSource() const { return source; }
int getData(const Ref<vector<int>> &data, const vector<int> &idx);
size_t getOffset(const vector<int> &idx);
};
enum class OpType { Range, Add, Mul, Div, Mod, Sub };
const char opSymbols[] = "#+*/%-";
class OperatorNode : public ExprNode {
protected:
const OpType opType;
VecExpr subExprs;
public:
OperatorNode(OpType _opType) : opType(_opType){};
OperatorNode(OpType _opType, VecExpr _subExprs)
: opType(_opType), subExprs(_subExprs){};
int getSubExprsNum() { return subExprs.size(); };
const VecExpr &getSubExprs() { return subExprs; }
const Expr &getSubExprs(int i) const { return subExprs[i]; }
OpType getOpType() const { return opType; };
void setOperands(int i, Expr e) { subExprs[i] = e; }
};
using Range = pair<int, int>;
using VarRangePair = pair<Var, Range>;
inline int getLength(const Range &range) { return range.second - range.first; }
struct IterationType {
enum { Loop, Sum };
constexpr static int NumIterationType = 2;
};
class RangeOpNode : public OperatorNode {
public:
enum { Summand, END_POS };
constexpr static int Loop = IterationType::Loop;
constexpr static int Sum = IterationType::Sum;
private:
vector<VarRangePair> vars[IterationType::NumIterationType];
vector<int> paddings;
public:
RangeOpNode(Expr _summand) : OperatorNode(OpType::Range, {_summand}){};
RangeOpNode(const vector<VarRangePair> &_loopIters,
const vector<VarRangePair> &_sumIters, Expr _summand,
const vector<int> &paddings)
: OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters},
paddings(paddings){};
DEFINE_GETTYPE(RangeOpNode);
virtual HashType hash() const override {
nnet_unimplemented_halt();
return 0;
};
string toReadable() const override;
const Expr &getSummand() const { return subExprs[Summand]; }
const vector<VarRangePair> &getVarRanges(int _index) const {
return vars[_index];
}
const vector<VarRangePair> &getLoopVarRanges() const {
return vars[IterationType::Loop];
}
const vector<VarRangePair> &getSumVarRanges() const {
return vars[IterationType::Sum];
}
int getNumOutputDims() const;
bool hasVar(int index, Var name) const;
bool hasLoopVar(Var name) const { return hasVar(Loop, name); }
bool hasSumVar(Var name) const { return hasVar(Sum, name); }
bool hasLoopVar(string name) const {
return hasVar(Loop, make_ref<VarNode>(name));
}
bool hasSumVar(string name) const {
return hasVar(Sum, make_ref<VarNode>(name));
}
int getVarIndex(int type, string name);
void setSummand(Expr e) { subExprs[Summand] = e; }
void setLoopIterator(const vector<VarRangePair> &vecExpr) {
vars[Loop] = vecExpr;
}
void setSumIterator(const vector<VarRangePair> &vecExpr) {
vars[Sum] = vecExpr;
}
void setIterator(const vector<VarRangePair> &loop,
const vector<VarRangePair> &sum) {
setLoopIterator(loop);
setSumIterator(sum);
}
const VarRangePair &getVarRange(int _index, int i) const {
return vars[_index][i];
}
const Var &getLoopVar(int i) const { return vars[Loop][i].first; }
Range getRange(const Var &var) const;
VarRangePair getVarRange(const Var &var) const;
bool hasPaddings() const;
int getPaddings(int dim) const;
vector<int> getPaddings() const;
void setPaddings(vector<int> _paddings);
void setVarRange(int _index, int i, VarRangePair pair) {
vars[_index][i] = pair;
}
int64_t getFlops() const;
int64_t getInputSize(const RangeOp &self) const;
int64_t getOutputSize() const;
vector<int> getOutputShape() const;
// Including paddings
vector<Range> getOutputRanges() const;
};
class BinaryOpNode : public OperatorNode {
enum { LHS, RHS, END_POS };
public:
BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs)
: OperatorNode(_opType, {_lhs, _rhs}){};
virtual ~BinaryOpNode() {}
DEFINE_GETTYPE(BinaryOpNode);
virtual HashType hash() const override {
return genhash((HashType)opType,
genhash(subExprs[LHS]->hash(), subExprs[RHS]->hash()));
};
virtual string toReadable() const override;
const Expr &getLhs() const { return getSubExprs(LHS); };
const Expr &getRhs() const { return getSubExprs(RHS); };
void setLhs(Expr e) { setOperands(LHS, e); };
void setRhs(Expr e) { setOperands(RHS, e); };
// If Var/constant, use this one
optional<pair<Var, int>> getModDivParameter() const;
// If (Var+constant)/constant, use this one
pair<Expr, int> getModDivExpr() const;
bool isSwapable() const;
};
class ConstantNode : public ExprNode {
int val;
public:
ConstantNode(int _val) : val(_val){};
ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){};
virtual ~ConstantNode() {}
DEFINE_GETTYPE(ConstantNode);
int getValue() const { return val; }
virtual HashType hash() const override { return genhash(val, 6214587); };
virtual string toReadable() const override {
string ret;
ret += std::to_string(val);
return ret;
};
};
class SubscriptNode : public ExprNode {
protected:
Expr indexed;
VecExpr subExprs;
public:
SubscriptNode(Expr _indexed, vector<Expr> _subExprs) : subExprs(_subExprs) {
setObject(_indexed);
};
DEFINE_GETTYPE(SubscriptNode);
virtual HashType hash() const override {
nnet_unimplemented_continue();
return -1;
};
virtual string toReadable() const override;
size_t getDims() const { return subExprs.size(); }
const VecExpr &getIndex() const { return subExprs; }
const Expr &getIndex(size_t i) const { return subExprs[i]; }
void setIndex(size_t i, Expr e) { subExprs[i] = e; }
Expr *getObjectPtr() { return &indexed; }
Expr getObject() const { return indexed; }
void setObject(Expr e);
bool isRangeOpSubscripted() const;
bool isTensorSubscripted() const { return !isRangeOpSubscripted(); }
// Get the ranges of objects including paddings
vector<Range> getObjectRangesWithPaddings() const;
vector<Range> getObjectRangesWithoutPaddings() const;
};
class FuncNode : public ExprNode {
protected:
Subscript object;
FuncType funcType;
public:
FuncNode(Expr object, FuncType funcType) : funcType(funcType) {
setObject(object);
}
DEFINE_GETTYPE(FuncNode);
virtual HashType hash() const override {
nnet_unimplemented_continue();
return -1;
};
virtual string toReadable() const override;
const Subscript &getObject() const { return object; }
void setObject(Expr e);
FuncType getFuncType() const { return funcType; }
};
// Wrappers for type deduction
Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts);
RangeOp makeRangeOperator(const vector<VarRangePair> &_loopIters,
const vector<VarRangePair> &_sumIters, Expr _summand,
const vector<int> &paddings = {});
Tensor makeTensor(const string &name, const vector<int> &shape,
const vector<int> &paddings = {},
const Routine &source = nullptr);
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
*_ = nullptr>
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
os << ((!a) ? string("nullptr") : a->toReadable());
return os;
}
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<std::is_base_of_v<ExprNode, T>>
*_ = nullptr>
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
os << ((!a) ? string("nullptr") : a->toReadable());
return os;
}
#undef DEFINE_GETTYPE
} // namespace nnet
namespace std {
template <> struct hash<nnet::VarNode &> {
size_t operator()(const nnet::VarNode &t) const {
return std::hash<string>()(t.getName());
}
};
} // namespace std

View File

@ -0,0 +1,234 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
namespace nnet {
using PatternTensorMap = vector<Tensor>;
using PatternIterRangeMap = PtrMap<Iterator, VarRangePair>;
enum class MismatchType {
// Search required (undetermined)
MoreVar,
LessVar,
StrideMismatch,
// guided DLT (determined)
DLMismatch,
OutputDLMismatch,
OutputDimismatch
};
struct Mismatch {
MismatchType type;
int bitmap; // Row ID of IT
PtrMap<Iterator, Iterator>
mappingIter_r; // For DLT mismatch, iters are mapped
Mismatch(MismatchType _type, int _bitmap) : type(_type), bitmap(_bitmap) {}
Mismatch(MismatchType _type, int _bitmap,
PtrMap<Iterator, Iterator> _mappingIter_r)
: type(_type), bitmap(_bitmap), mappingIter_r(_mappingIter_r) {}
};
class Pattern;
class IteratorTable {
protected:
// using Appearance = map<string, vector<pair<Tensor, int>>>;
// using StrideTable = map<TensorNode *, vector<tuple<string, int, int>>>;
// // Var, dim, stride
RangeOp rangeOp;
// To real tensor
// FIXME: redundent
Appearance appearance;
vector<Tensor> tensors; // original tensor sequence
vector<Subscript> subscripts; // original subscripts sequence
StrideTable strideTable; // TODO [Refactor]: rename strideTable
PatternIterRangeMap iterToRange;
// mapping
vector<int> tensorMap; // [index for tensors] -> tensorID in pattern
PtrMap<Iterator, Iterator> iterMap; // [expr iter] -> pattern iter
// final data
vector<vector<Iterator>> posTable; // [Tensor bitmap]=[Iterator]
vector<vector<vector<Iterator>>>
iterInTensorDim; // [tensorID][dimOfTensor]=[Iterator],
// stride in each dim may be add
vector<vector<PtrMap<Iterator, int>>>
strideInDim; // [tensorID][dimOfTensor][Iterator]=stride,
// stride in each dim may be add
PtrMap<Iterator, vector<int>> strideInTensor; // [Iterator][tensorID]=stride
// final data: auxiliary data
vector<int> tensorIDMap_r;
PatternTensorMap tensorMap_r;
PatternIterRangeMap iterToRange_r;
public:
virtual ~IteratorTable() {}
IteratorTable() {}
IteratorTable(const IteratorTable &) = delete;
[[nodiscard]] bool analyzeExpr(const RangeOp &rangeOp);
// mapTensors
void buildTable(const vector<int> &_tensorMap);
void buildTableWithDefaultMap();
/**
* @brief Check whether the expression match a pattern. If not, return the
* detailed reason for guided search.
*
* @param patternIT
* @return vector<int> mismatched IT rows/tensors for guided DLT.
*/
vector<Mismatch> matchPatternIT(const Pattern &patternIT);
void matchIterators();
int getNumInputs() const { return tensors.size(); }
int getNumTensors() const { return tensors.size() + 1; }
int getNumRows() const { return 1 << getNumTensors(); }
int getNumIterators() const { return strideTable.size(); }
// vector<Tensor> tensorMap_r(
// pattern.nInputs); // [pattern tensor ID] -> real tensor
// map<string, VarRangePair> iterToRange_r; // [pattern iter] -> iter &
// range
auto getTables() const {
return tuple(posTable, iterInTensorDim, strideInTensor);
}
const auto &getStrideInDim() const { return strideInDim; }
vector<vector<Iterator>> getIterInTensorDim(int tensorID) const {
return iterInTensorDim[tensorID];
}
const vector<Iterator> &getPosTable(int bitmap) const {
return posTable[bitmap];
}
pair<PatternTensorMap, PatternIterRangeMap> getReverseMap() const;
int getStridesInTensor(Iterator iter, int tensorID) const;
vector<int> getIterDimInTensor(int tensorID, const Iterator &iter) const;
Tensor getTensor(int tensorID) const { return tensorMap_r[tensorID]; }
Subscript getSubscript(int tensorID) const {
return subscripts[tensorIDMap_r[tensorID]];
}
Range getIterRange(const Iterator &iter) const {
return rangeOp->getRange(iter);
}
/**
* @brief Check strides of each iterators and there position in tensors.
* Since many-to-many iterators matching exist, we take this procudure as a
* seperate function to deal with different iterator mapping solution.
*
* @param patternIT
* @param mappingIter_r
* @return vector<Mismatch>
*/
vector<Mismatch>
matchPatternITCheckStrides(const Pattern &patternIT,
PtrMap<Iterator, Iterator> mappingIter_r);
RangeOp getRangeOp() const;
};
struct StrideConstraint {
int tensorID;
Var v0, v1;
enum class Constraint { SAME, PROPOTIONAL } type;
};
class Pattern : public IteratorTable {
vector<StrideConstraint> strideConstraints;
public:
virtual Expr
buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const = 0;
/**
* @brief Check whether all indexes only are a iterator
*
* @param tensorID
*/
bool isAllUniqueAccess(int tensorID) const;
const auto &getStrideConstraints() const { return strideConstraints; };
int calcPadding(const Tensor &tensor, int dim, Range rangeH, Range rangeR,
int offset) const;
};
class MatmulPattern : public Pattern {
public:
static const Pattern &getMatmulPattern();
static pair<Expr, pair<Tensor, Tensor>> getExpr(bool transA, bool transB,
int b, int m, int n, int k);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
class ConvPattern : public Pattern {
private:
static const Var n, c, h, w, f, r, s;
public:
static const Pattern &getPattern();
static Expr getExpr(Tensor A, Tensor K, int n, int c, int h, int w, int f,
int r, int s);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
class ConvTransPattern : public Pattern {
private:
static const Var n, c, h, w, f, r, s;
public:
static const Pattern &getPattern() = delete;
static Expr getExpr(Tensor A, Tensor K, int N, int C, int H, int W, int F,
int R, int S);
Expr
buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override {
nnet_unimplemented_halt();
return nullptr;
};
};
class Sg2bmmPattern : public Pattern {
private:
static const Var b, m, w, k;
public:
static const Pattern &getPattern();
static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int K,
int W, int D);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
class LongformerGBMMPattern : public Pattern {
private:
static const Var b, m, w, n;
public:
static const Pattern &getPattern();
static pair<Expr, pair<Tensor, Tensor>> getExpr(int Batch, int M, int W,
int K, int dilation);
Expr buildExpr(const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges,
string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const override;
};
const Pattern &getPattern(RoutineType targetOp);
string getPatternName(RoutineType targetOp);
} // namespace nnet

57
include/nnet/nmutator.h Normal file
View File

@ -0,0 +1,57 @@
#pragma once
#include "core/mutator.h"
#include "nnet/expr.h"
#ifdef ABC
namespace infini {
class NMutator : public Mutator {
private:
// Suffix -N: NNet objects.
// Suffix -T: tpm objects.
// Map: NNet tensors -> tpm tensor.
std::map<std::string, Tensor> inputsNameNToTensorT;
enum class Mode { Normal, ToNaiveMembound, RuleBased } mode = Mode::Normal;
const double bandwidth = double(200) * 1024 * 1024 * 1024;
// If in RuleBased mode, use derivationRules in derivator
const std::vector<int> derivationRules;
public:
NMutator();
NMutator(const std::vector<int> &derivationRules);
~NMutator();
vector<Graph> run(const Graph &in_graph) override;
void setToNaiveMembound();
void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
long long cntStates = 0;
long long cntCandidates = 0;
private:
int maxDepth = 8;
nnet::Expr opToExpression(Operator op);
void runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs);
/**
* @brief Test helper. Converting a single OP to Membound Op for
* corretness check.
*/
void runSingleOpToNaiveMembound(Graph in_graph,
std::vector<Graph> &out_graphs);
void runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs);
Graph expressionToGraph(nnet::Expr expr, Graph in_graph);
Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
double memboundTime(ssize_t cnt);
double memboundTime(const Shape &dims);
Graph transformTConv1x1(Operator op);
Graph transformTConv3x3(Operator op);
Graph transformDialtedConv(Operator op);
Graph transformConv1x1(Operator op);
Graph transformConv1xk(Operator op);
};
} // namespace infini
#endif

View File

@ -0,0 +1,38 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
namespace nnet {
class PermutationGenerator {
vector<vector<Iterator>> from, to;
vector<vector<size_t>> mapping;
public:
PermutationGenerator(vector<vector<Iterator>> _from,
vector<vector<Iterator>> _to);
bool next();
PtrMap<Iterator, Iterator> get() const;
};
template <typename T> class SubsetGenerator {
vector<T> elements;
int n, bitmap;
public:
SubsetGenerator(vector<T> elements, bool nonEmpty = 1)
: elements(elements), n(elements.size()), bitmap((nonEmpty > 0)) {
assert(n < 10);
};
bool next() { return ((++bitmap) < (1 << n) - 1); }
vector<T> get() const {
vector<T> ret;
for (int i = 0; i < n; ++i)
if (bitmap & (1 << i))
ret.emplace_back(elements[i]);
return ret;
}
};
} // namespace nnet

200
include/nnet/ref.h Normal file
View File

@ -0,0 +1,200 @@
#pragma once
#include "common.h"
#include <functional> // hash
#include <memory>
#include <type_traits>
namespace nnet {
template <typename T> struct is_ref;
/**
* Ref-counting pointer
*
* This class is thread-safe (For developers: concurrent accesses through
* different `std::shared_ptr`s to the same object is already thread-safe, while
* modifying the same `std::shared_ptr` is not. We never modify a `Ref`, so no
* locks are needed. See https://en.cppreference.com/w/cpp/memory/shared_ptr)
*/
template <class T> class Ref {
static_assert(is_ref<T>::value == false, "Ref should not be nested");
template <class U> friend class Ref;
std::shared_ptr<T> ptr_;
private:
public:
typedef T Object;
Ref() = default;
// Ref(std::nullptr_t) : Ref() {}
constexpr Ref(nullptr_t) noexcept : Ref() {}
Ref(const Ref &) = default;
Ref(Ref &&) = default;
Ref(std::shared_ptr<T> &&ptr) : ptr_(std::move(ptr)) {}
// Ref(const std::shared_ptr<T> &ptr) : ptr_(ptr) {}
// /// NO NOT USE THIS CONSTRUCTOR IN PUBLIC
// /// It is public because Pybind11 needs it
// Ref(T *ptr) : ptr_(ptr) {}
/**
* Shared with any compatible references
*/
template <class U,
typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
Ref(const Ref<U> &other) : ptr_(std::static_pointer_cast<T>(other.ptr_)) {}
template <class U,
typename std::enable_if_t<std::is_base_of_v<T, U>> * = nullptr>
Ref &operator=(const Ref<U> &other) {
ptr_ = std::static_pointer_cast<T>(other.ptr_);
return *this;
}
Ref &operator=(const Ref &) = default;
Ref &operator=(Ref &&) = default;
template <class U> Ref<U> as() const {
Ref<U> ret;
ret.ptr_ = std::dynamic_pointer_cast<U>(ptr_);
return ret;
}
bool isValid() const { return ptr_ != nullptr; }
T &operator*() const {
nnet_assert(isValid(), "Empty pointer.");
return *ptr_;
}
T *operator->() const {
nnet_assert(isValid(), "Empty pointer.");
return ptr_.get();
}
T *get() const {
nnet_assert(isValid(), "Empty pointer.");
return ptr_.get();
}
friend inline bool operator==(const Ref &lhs, nullptr_t) {
return !lhs.isValid();
}
friend inline bool operator!=(const Ref &lhs, nullptr_t) {
return !(lhs == nullptr);
}
explicit operator bool() const { return ptr_ != nullptr; }
bool operator!() { return ptr_ == nullptr; }
void swap(Ref &__b) noexcept { ptr_.swap(__b.ptr_); }
};
template <class T, class U,
typename std::enable_if_t<std::is_base_of_v<U, T>> * = nullptr>
Ref<T> as(const Ref<U> &ref) {
return ref.template as<T>();
}
template <typename T, typename... Params> Ref<T> make_ref(Params &&...params) {
return Ref(make_shared<T>(std::forward<Params>(params)...));
}
// Comparator for Ref
template <typename T> struct is_ref : std::false_type {};
template <typename T> struct is_ref<Ref<T>> : std::true_type {};
template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<not is_ref<std::tuple_element_t<index, Tuple>>::value,
bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
if constexpr (index >=
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
return std::get<index>(lhs) < std::get<index>(rhs);
else {
if (std::get<index>(lhs) != std::get<index>(rhs))
return std::get<index>(lhs) < std::get<index>(rhs);
else
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
}
}
template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<is_ref<std::tuple_element_t<index, Tuple>>::value and
not address_based,
bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
if constexpr (index >=
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
return std::get<index>(lhs)->less(std::get<index>(rhs));
else {
if (std::get<index>(lhs)->neq(std::get<index>(rhs)))
return std::get<index>(lhs)->less(std::get<index>(rhs));
else
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
}
}
template <class Tuple, std::size_t index = 0, bool address_based>
typename std::enable_if_t<
is_ref<std::tuple_element_t<index, Tuple>>::value and address_based, bool>
__ref_less(const Tuple &lhs, const Tuple &rhs) {
if constexpr (index >=
std::tuple_size<std::remove_reference_t<Tuple>>::value - 1)
return std::get<index>(lhs).get() < std::get<index>(rhs).get();
else {
if (std::get<index>(lhs).get() != std::get<index>(rhs).get())
return std::get<index>(lhs).get() < std::get<index>(rhs).get();
else
return __ref_less<Tuple, index + 1, address_based>(lhs, rhs);
}
}
template <class Tuple> bool ref_addr_less(const Tuple &lhs, const Tuple &rhs) {
return __ref_less<Tuple, 0, true>(lhs, rhs);
}
template <class Tuple> bool ref_value_less(const Tuple &lhs, const Tuple &rhs) {
return __ref_less<Tuple, 0, false>(lhs, rhs);
}
template <class Tuple> class RefAddrLess {
public:
bool operator()(const Tuple &a, const Tuple &b) const {
return ref_addr_less(a, b);
}
};
template <class Tuple> class RefValueLess {
public:
bool operator()(const Tuple &a, const Tuple &b) const {
return ref_value_less(a, b);
}
};
// make_ref_from_tuple
template <typename _Tp, typename _Tuple, size_t... _Idx>
constexpr Ref<_Tp> make_ref_from_tuple_impl(_Tuple &&__t,
std::index_sequence<_Idx...>) {
return make_ref<_Tp>(std::get<_Idx>(std::forward<_Tuple>(__t))...);
}
template <typename _Tp, typename _Tuple>
constexpr Ref<_Tp> make_ref_from_tuple(_Tuple &&__t) {
return make_ref_from_tuple_impl<_Tp>(
std::forward<_Tuple>(__t),
std::make_index_sequence<std::tuple_size_v<std::decay_t<_Tuple>>>{});
}
} // namespace nnet
// namespace std {
// template <class T> struct hash<ir::Ref<T>> {
// hash<T *> hash_;
// size_t operator()(const ir::Ref<T> &ref) const { return hash_(ref.get());
// }
// };
// } // namespace nnet

158
include/nnet/routine.h Normal file
View File

@ -0,0 +1,158 @@
#pragma once
#include "common.h"
#include "expr.h"
#include <iostream>
#include <sstream>
namespace nnet {
class RoutineNode;
class MatmulNode;
class ElementWiseNode;
using Routine = Ref<RoutineNode>;
using Matmul = Ref<MatmulNode>;
using ElementWise = Ref<ElementWiseNode>;
#define DEFINE_GETTYPE(CLASS) \
RoutineType getType() const override { return RoutineType::CLASS##Type; }
class RoutineNode {
protected:
Expr expr;
vector<Tensor> inputs;
public:
RoutineNode(Expr _expr, const vector<Tensor> &_inputs);
virtual string toReadable() const = 0;
const Expr &getExpr() const { return expr; }
const vector<Tensor> &getInputs() const { return inputs; }
virtual RoutineType getType() const = 0;
};
using MatmulArgs = tuple<int, // b
int, // m
int, // n
int, // k
bool, // transa
bool>; // transb
class MatmulNode : public RoutineNode {
int b, m, n, k;
bool transa, transb;
public:
MatmulNode(Expr _source, Tensor A, Tensor B, int _b, int _m, int _n, int _k,
bool _transa, bool _transb)
: RoutineNode(_source, {A, B}), b(_b), m(_m), n(_n), k(_k),
transa(_transa), transb(_transb) {}
DEFINE_GETTYPE(MatmulNode);
string toReadable() const override;
friend bool operator==(const MatmulNode &lhs, const MatmulNode &rhs);
MatmulArgs getArgs() { return tuple(b, m, n, k, transa, transb); }
};
using ConvArgs = tuple<int, // ph
int, // pw
int, // sh
int, // sw
int, // dh
int>; // dw
class ConvNode : public RoutineNode {
int ph, pw;
int sh, sw;
int dh, dw;
public:
ConvNode(Expr _source, Tensor A, Tensor K, int _ph, int _pw, int _sh = 1,
int _sw = 1, int _dh = 1, int _dw = 1)
: RoutineNode(_source, {A, K}), ph(_ph), pw(_pw), sh(_sh), sw(_sw),
dh(_dh), dw(_dw) {}
DEFINE_GETTYPE(ConvNode);
string toReadable() const override;
vector<int> getShape() const;
friend bool operator==(const ConvNode &lhs, const ConvNode &rhs);
ConvArgs getArgs() const;
};
class ElementWiseNode : public RoutineNode {
vector<int> outputShape;
public:
// _outputShape is redundent, but expr is still missing for DLT.
ElementWiseNode(Expr _source, vector<Tensor> _inputs,
vector<int> _outputShape)
: RoutineNode(_source, _inputs), outputShape(_outputShape) {}
DEFINE_GETTYPE(ElementWiseNode);
string toReadable() const override;
/**
* @brief Get the Estimated Time of mem bound OP.
*
* @return double Time in ms.
*/
double getEstimatedTime() const;
const vector<int> &getOutputShape() const { return outputShape; }
};
using G2bmmArgs = tuple<int, // b
int, // m
int, // w
int, // k
int>; // dilation
class G2bmmNode : public RoutineNode {
int b, m, w, k;
public:
G2bmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int k,
int d = 1)
: RoutineNode(source, {A, B}), b(b), m(m), w(w), k(k) {
assert(d == 1);
}
DEFINE_GETTYPE(G2bmmNode);
vector<int> getShape() const;
string toReadable() const override;
G2bmmArgs getArgs() const;
};
using GbmmArgs = tuple<int, // b
int, // m
int, // w
int, // n
int>; // dilation
class GbmmNode : public RoutineNode {
int b, m, w, n;
public:
GbmmNode(Expr source, Tensor A, Tensor B, int b, int m, int w, int n,
int d = 1)
: RoutineNode(source, {A, B}), b(b), m(m), w(w), n(n) {
assert(d == 1);
}
DEFINE_GETTYPE(GbmmNode);
vector<int> getShape() const;
string toReadable() const override;
GbmmArgs getArgs() const;
};
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<
std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const shared_ptr<T> &a) {
os << ((!a) ? string("Null shared_ptr") : a->toReadable());
return os;
}
// Pretty output for dbg with shared_ptr
template <typename T, typename std::enable_if_t<
std::is_base_of_v<RoutineNode, T>> *_ = nullptr>
std::ostream &operator<<(std::ostream &os, const Ref<T> &a) {
os << ((!a) ? string("Null shared_ptr") : a->toReadable());
return os;
}
} // namespace nnet

28
include/nnet/test.h Normal file
View File

@ -0,0 +1,28 @@
#pragma once
#include "common.h"
#include "derivator.h"
// clang-format off
#define CAT(A, B) A##B
#define SELECT(NAME, NUM) CAT(NAME##_, NUM)
#define GET_COUNT( _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, COUNT, ... ) COUNT
#define VA_SIZE( ... ) GET_COUNT( __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 )
#define VA_SELECT( NAME, ... ) SELECT( NAME, VA_SIZE(__VA_ARGS__) )(__VA_ARGS__)
#define _DEFVAR_1(name) auto name = make_ref<VarNode>(#name);
#define _DEFVAR_2(name, ...) _DEFVAR_1(name); _DEFVAR_1(__VA_ARGS__)
#define _DEFVAR_3(name, ...) _DEFVAR_1(name); _DEFVAR_2(__VA_ARGS__)
#define _DEFVAR_4(name, ...) _DEFVAR_1(name); _DEFVAR_3(__VA_ARGS__)
#define _DEFVAR_5(name, ...) _DEFVAR_1(name); _DEFVAR_4(__VA_ARGS__)
#define _DEFVAR_6(name, ...) _DEFVAR_1(name); _DEFVAR_5(__VA_ARGS__)
#define _DEFVAR_7(name, ...) _DEFVAR_1(name); _DEFVAR_6(__VA_ARGS__)
#define _DEFVAR_8(name, ...) _DEFVAR_1(name); _DEFVAR_7(__VA_ARGS__)
#define _DEFVAR_9(name, ...) _DEFVAR_1(name); _DEFVAR_8(__VA_ARGS__)
#define DEFINE_VAR(...) VA_SELECT(_DEFVAR, __VA_ARGS__)
// clang-format on
namespace nnet {
int matchExprResult(Derivator &derivator, string fn);
bool checkExprLogSame(string fnPrefix, int start, int end);
bool checkExprsEquvivalence(VecExpr exprs);
} // namespace nnet

128
include/nnet/visitor.h Normal file
View File

@ -0,0 +1,128 @@
#pragma once
#include "common.h"
#include "derivator.h"
#include "expr.h"
#include "routine.h"
#include <iostream>
#include <unordered_map>
namespace nnet {
template <typename FType> class Functor;
template <typename R, typename... Args> class Functor<R(Args...)> {
protected:
int verbose;
// FIXME: scope should be protected
public:
Functor(int _verobse = 0) : verbose(_verobse) {}
virtual ~Functor() = default;
#define DISPATCH(CLASS) \
case NodeType::CLASS##Type: \
return this->visit_(as<CLASS>(c), std::forward<Args>(args)...); \
break
#define FUNCTOR_DEFAULT \
{ return visitDefault(c, std::forward<Args>(args)...); }
virtual R dispatch(const Expr &c, Args... args) {
switch (c->getType()) {
DISPATCH(ConstantNode);
DISPATCH(BinaryOpNode);
DISPATCH(RangeOpNode);
DISPATCH(SubscriptNode);
DISPATCH(TensorNode);
DISPATCH(VarNode);
DISPATCH(FuncNode);
default:
nnet_assert(0, "Unknown type");
return R();
}
}
virtual R visit_(const Constant &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const BinaryOp &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const RangeOp &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Subscript &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Var &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT;
virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) {
dbg(*c);
nnet_assert(0, "Reach unimplemented visit function.");
return R();
};
[[deprecated("Define explicit methods for public access.")]] R
operator()(const Expr &e, Args... args) {
return dispatch(e, std::forward<Args>(args)...);
}
#undef FUNCTOR_DEFAULT
#undef DISPATCH
};
class Mutator : public Functor<Expr()> {
public:
Mutator(int _verobse = 0) : Functor(_verobse) {}
Expr visit_(const Constant &c) override;
Expr visit_(const BinaryOp &c) override;
Expr visit_(const RangeOp &c) override;
Expr visit_(const Subscript &c) override;
Expr visit_(const Var &c) override;
Expr visit_(const Tensor &c) override;
Expr visit_(const Func &c) override;
};
// template <typename... Args>
// class SingleStageVisitor : public Functor<void, Args...> {
// public:
// SingleStageVisitor(int _verobse = 0) : Functor<R, Args...>(_verobse) {}
// // R visit(const Constant &c) override ;
// R visit_(const BinaryOp &c) override {
// if (verbose)
// dbg(*c);
// this->dispatch(c->getLhs());
// this->dispatch(c->getRhs());
// }
// R visit_(const RangeOp &c) override {
// if (verbose)
// dbg(*c);
// this->dispatch(ret->getSummand());
// // NOT visit iterators and its ranges
// }
// R visit_(const Subscript &c) override {
// if (verbose)
// dbg(*c);
// this->dispatch(ret->getObject());
// for (size_t i = 0; i < ret->getDims(); ++i)
// this->dispatch(ret->getIndex(i));
// }
// // R visit(const Var &c) override;
// // R visit(const Tensor &c) override;
// };
// } // namespace nnet
// #include "nnet/Visitor/ReplaceVariable.h"
// #include "nnet/Visitor/StrideVisitor.h"
// namespace nnet {
class ExprTreeVisitor : public Functor<void(void)> {
private:
bool inBinary, inRange, inSub, inTensor;
public:
ExprTreeVisitor(bool _inBinary = 1, bool _inRange = 1, bool _inSub = 1,
bool _inTensor = 1, int _verobse = 0)
: Functor(_verobse), inBinary(_inBinary), inRange(_inRange),
inSub(_inSub), inTensor(_inTensor) {}
void visit_(const Constant &c) override;
void visit_(const BinaryOp &c) override;
void visit_(const RangeOp &c) override;
void visit_(const Subscript &c) override;
void visit_(const Var &c) override;
void visit_(const Tensor &c) override;
void visit_(const Func &c) override;
};
} // namespace nnet

View File

@ -0,0 +1,25 @@
#include "nnet/Pass/MatchComputationKernel.h"
#include "nnet/Visitor/PatternMatcher.h"
namespace nnet {
// RE: is this duplicate with Rule6KenerlMatching?
void MatchComputationKernel::transform(Formula &origin, int depth, Expr &rCur) {
nnet_assert(derivator.getSearchState() == 2, __LINE__);
auto cur = as<RangeOpNode>(rCur);
// Build wrapper stages for enforce axis starts from 0
PatternMatcher patternMatcher(derivator, cur);
cur = patternMatcher.getOffsetCur();
auto matches = patternMatcher.matchWithPattern(
cur, getPattern(derivator.getTargetOp()));
matches = patternMatcher.applyWrapper(matches);
for (auto newCur : matches) {
derivator.setSearchState(3);
nextStep(origin, depth, rCur, newCur);
derivator.setSearchState(2);
}
}
} // namespace nnet

View File

@ -0,0 +1,23 @@
#include "nnet/Pass/MatchMemBoundKernel.h"
#include "nnet/Visitor/InputVisitor.h"
namespace nnet {
void MatchMemBoundKernel::transform(Formula &origin, int depth, Expr &rCur) {
// FIXME: Whether the Formula is a Membound OP should be checked.
nnet_assert(derivator.getSearchState() == 3, __LINE__);
nnet_assert(origin.root.get() == rCur.get(),
"Only match the entire formula as a Membound Op");
auto rangeOp = as<RangeOpNode>(origin.root);
const auto &inputs = InputVisitor().getInputs(rangeOp);
auto source =
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
auto tensor =
makeTensor(newTensorName(), rangeOp->getOutputShape(), {}, source);
// The original code directly appends candidate. But it seems should be done
// by the search.
// appendCanddiate(as<TensorNode>(tensor), depth);
nextStep(origin, depth, rCur, tensor);
}
} // namespace nnet

58
src/nnet/Pass/Pass.cc Normal file
View File

@ -0,0 +1,58 @@
#include "nnet/Pass/Pass.h"
#include "nnet/Visitor/CloneMutator.h"
namespace nnet {
Pass::Pass(Derivator &derivator, const string &passName)
: derivator(derivator), passName(passName),
enableLogging(derivator.getLogMode() != Derivator::LogMode::NoLog),
enableDebug(false) {}
Pass::~Pass() = default;
void Pass::setEnableLogging(bool value) { enableLogging = value; }
void Pass::setEnableDebug(bool value) { enableDebug = value; }
void Pass::run(Formula &origin, int dfsDepth, Expr &rCur) {
initialize(origin, rCur);
transform(origin, dfsDepth, rCur);
finalize();
}
void Pass::initialize(Formula &origin, const Expr &rCur) {}
void Pass::finalize() {}
Var Pass::getNewVar() { return derivator.getNewVar(); }
string Pass::newTensorName() { return derivator.newTensorName(); }
void Pass::nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur,
const string &ruleMsg) {
// push rule action description
if (enableLogging) {
rCur.swap(newCur);
derivator.pushIntermediateState(origin.root);
rCur.swap(newCur);
derivator.pushRuleState(passName);
derivator.pushRuleMsg(ruleMsg);
}
if (enableDebug) {
// In debug mode, do not recur but save the transformed state
transformations.emplace_back(CloneMutator().clone(newCur));
} else
derivator.nextStep(origin, depth, rCur, newCur);
// pop rule action description
if (enableLogging) {
derivator.popIntermediateState();
derivator.popRuleState();
derivator.popRuleMsg();
}
}
const VecExpr &Pass::getTransformations() { return transformations; }
} // namespace nnet

View File

@ -0,0 +1,134 @@
#include "nnet/Pass/Rule1VariableSplit.h"
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
void Rule1VariableSplit::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
vector<Replace> replaces = getSplitableVar(cur);
// for (const auto &replace : replaces)
// dbg(replace.oldIters, replace.newIters, replace.psis,
// replace.newVarRanges);
for (const auto &replace : replaces) {
auto replacedSummand = replaceIters(cur->getSummand(), replace);
if (!replacedSummand) {
// TODO: if a real getMergableExprs is implemented, this case should
// be an error. Since the expr should appear in the AST.
dbg("Warning: No replacment happens.");
continue;
}
auto inner =
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
// build the outerRange{innerRange}[indexForInner] to do DLT
Expr nextCur = nullptr;
if (replace.iteratorType == IterationType::Loop) {
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
nextCur = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
} else
nextCur = inner;
string msg = "====== END rule1 VariableSplit: ";
dbg(msg, replace.oldIters, replace.newIters, replace.phis,
replace.psis);
msg = replace.toReadable();
nextStep(origin, depth, rCur, nextCur, msg);
}
}
vector<Replace> Rule1VariableSplit::getSplitableVar(const RangeOp &rangeOp) {
vector<Replace> ret;
// Split strategy
vector<int> SumFactors, LoopFactors;
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
SumFactors = {3};
LoopFactors = {4};
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
SumFactors = {2, 3};
// LoopFactors = {3, 4};
LoopFactors = {4};
} else
nnet_unimplemented_halt();
// Split Sum variable
for (const int k : SumFactors) {
for (const auto &[var, range] : rangeOp->getSumVarRanges()) {
int len = range.second - range.first;
auto p1 = getNewVar(); // p1=i/k
auto p2 = getNewVar(); // p2=i%k
if (len > 10 || len <= k || len % k != 0)
continue;
Range range1, range2;
if (range.first < 0) {
nnet_unimplemented_halt();
// FIXME: this must be ERROR
range1.first = range.first / k;
range1.second = range1.first + len / k;
range2.first = -k / 2;
range2.second = range2.first + k;
} else if (range.first == 0) {
range1.first = 0;
range1.second = len / k;
range2.first = 0;
range2.second = k;
} else {
nnet_unimplemented_continue();
continue;
}
Replace replace{.iteratorType = IterationType::Sum,
.oldIters = {var},
.newIters = {p1, p2},
.phis = {},
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
.newVarRanges = {{p1, range1}, {p2, range2}}};
ret.emplace_back(replace);
}
}
for (const int k : LoopFactors) {
// Split Loop variable
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
const int len = range.second - range.first;
// Debug HACK for dilated SG2BMM
if (derivator.getPassMode() == Derivator::PassMode::Debug &&
!(var->getName() == "m" && len % k == 0))
continue;
// Illeagel conditions
if (range.second - range.first <= k ||
(range.second - range.first) % k != 0)
continue;
// Unsupport conditions
if (range.first != 0)
continue;
auto p1 = getNewVar(); // p1=i/k
auto p2 = getNewVar(); // p2=i%k
Range range1(0, len / k);
Range range2(0, k);
nnet_assert(range1.second > 0 && range2.second > 0,
"Empty loop dim");
Replace replace{.iteratorType = IterationType::Loop,
.oldIters = {var},
.newIters = {p1, p2},
.phis = {var / 4, var % 4},
.psis = {make_ref<ConstantNode>(k) * p1 + p2},
.newVarRanges = {{p1, range1}, {p2, range2}}};
ret.emplace_back(replace);
}
}
return ret;
}
Expr Rule1VariableSplit::replaceIters(Expr cur, const Replace &replace) {
// TODO [feature]: support multiple replacements in one mutator
if (replace.oldIters.size() != 1) {
nnet_unimplemented_continue();
return nullptr;
}
auto replaceMutator =
ReplaceVariable(replace.oldIters.at(0), replace.psis.at(0));
auto ret = replaceMutator(cur);
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,186 @@
#include "nnet/Pass/Rule2VariableMerging.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
namespace nnet {
void Rule2VariableMerging::transform(Formula &origin, int depth, Expr &rCur) {
// Extract r and s
auto cur = as<RangeOpNode>(rCur);
vector<Replace> replaces = getMergableReplaces(cur, depth);
// dbg("Start rule2VariableMerging", depth, mergableExprs, *cur);
for (const auto &replace : replaces) {
if (replace.iteratorType != IterationType::Loop) {
nnet_unimplemented_continue();
continue;
}
// replace vars in summand
auto replacedSummand = ReplaceKit::replaceMultipleExprs(
cur->getSummand(), replace.oldIters, replace.psis, true);
// replace var in rangeOp
auto inner =
ReplaceKit::replaceRangeOpIterator(cur, replace, replacedSummand);
// If OOB happens, this transformation is skipped
if (CheckOOBVisitor().checkRangeOp(inner))
continue;
// build the outerRange{innerRange}[indexForInner] to do DLT
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
// next searching step
string msg = replace.toReadable();
nextStep(origin, depth, rCur, outer, msg);
}
}
vector<Replace> Rule2VariableMerging::getMergableReplaces(RangeOp rangeOp,
int depth) {
vector<Replace> ret;
IteratorTable exprIT;
if (!exprIT.analyzeExpr(rangeOp)) {
nnet_unimplemented_continue();
return ret;
}
exprIT.buildTableWithDefaultMap();
const auto &strideInAllDim = exprIT.getStrideInDim();
set<pair<Iterator, Iterator>, RefValueLess<pair<Iterator, Iterator>>>
checkedIterPairs{};
// strideInAllDim: [tensorID][dimOfTensor][Iterator]=stride
for (size_t tensorID = 0; tensorID < strideInAllDim.size(); ++tensorID) {
const auto &strideInDimsOfATensor = strideInAllDim[tensorID];
for (const PtrMap<Iterator, int> &strideInADim :
strideInDimsOfATensor) {
for (const auto &it1 : strideInADim) {
for (const auto &it2 : strideInADim) {
// Backdoor for rule-based search
if (substituteRules.count(depth)) {
if (substituteRules[depth].at(0)->neq(it1.first))
continue;
if (substituteRules[depth].at(1)->neq(it2.first))
continue;
}
if (!(it1.first->equal(it2.first) &&
it1.second == it2.second) &&
rangeOp->hasLoopVar(it1.first) &&
rangeOp->hasLoopVar(it2.first)) {
// 2 iters -> 2 iters
if (auto opt = getReplaceMappingTwoLoopIters(rangeOp,
it1, it2))
ret.emplace_back(*opt);
// 2 iters -> 1 iter
const auto iterPair = pair(it1.first, it2.first);
if (!checkedIterPairs.count(iterPair)) {
checkedIterPairs.insert(iterPair);
if (auto opt = getReplaceMergingTwoLoopIters(
rangeOp, it1, it2, exprIT, tensorID))
ret.emplace_back(*opt);
}
}
}
}
}
}
return ret;
}
optional<Replace> Rule2VariableMerging::getReplaceMergingTwoLoopIters(
const RangeOp &rangeOp, pair<Iterator, int> pairA,
pair<Iterator, int> pairB, const IteratorTable &exprIT, int tensorID) {
// 1*A + sb*B -> C
// A=C%sb, B=C/sb
// ax+by->z, a=1 or -1
// For a>0 and b>0 : x=z%b, y=z/b
auto x = pairA.first, y = pairB.first;
int a = pairA.second, b = pairB.second;
if (abs(a) != 1 || abs(a) * abs(b) <= 0)
return {};
if (a < 0 && b > 0) { // The only unhandled case
nnet_unimplemented_continue();
return {};
}
// negative substitution happens only if can be totally merged. So if the
// variable appears in another index, skip it.
if (a < 0 || b < 0) {
if (exprIT.getNumInputs() > 1) {
if (exprIT.getStridesInTensor(x, 1 - tensorID) != 0)
return {};
if (exprIT.getStridesInTensor(y, 1 - tensorID) != 0)
return {};
}
}
Range rangeX = rangeOp->getVarRange(x).second,
rangeY = rangeOp->getVarRange(y).second;
if (rangeX.first != 0 || rangeY.first != 0)
return {};
int lenX = rangeX.second - rangeX.first;
if (abs(b) != lenX)
return {};
auto z = getNewVar();
Range rangeExpr{0, 1}; // 1 is the open interval compensation
auto calcRangeExpr = [&rangeExpr](int stride, const Range &r) {
if (stride > 0) {
rangeExpr.first += stride * r.first;
rangeExpr.second += stride * (r.second - 1);
} else {
rangeExpr.first += stride * (r.second - 1);
rangeExpr.second += stride * r.first;
}
};
calcRangeExpr(a, rangeX);
calcRangeExpr(b, rangeY);
// build the phi/psi for index transformation
// phi: j_x=(i_x...), psi: i_x=(j_x...)
auto ret = optional<Replace>();
ret.emplace();
ret->iteratorType = IterationType::Loop;
ret->newIters = {z};
ret->oldIters = {x, y};
ret->phis = {a * x + b * y - rangeExpr.first};
// For b < 0, the psis are not an equavalent replace. Since it must be
// simplified (z/b and z%b will be merged), the only important thing is
// their strides should be mergable. To merge the strides, an extra minus
// are introduced if their stride is negative.
ret->psis = {a * (z % b) + a * rangeExpr.first, (b > 0 ? 1 : -1) * (z / b)};
ret->newVarRanges = {{z, {0, rangeExpr.second - rangeExpr.first}}};
return ret;
}
optional<Replace>
Rule2VariableMerging::getReplaceMappingTwoLoopIters(const RangeOp &rangeOp,
pair<Iterator, int> pairA,
pair<Iterator, int> pairB) {
// the first iterator is replaced, the second remains
auto i1 = pairA.first, i2 = pairB.first;
int sa = pairA.second, sb = pairB.second;
// TODO: can be relaxed to sb|sb
if (sa != 1 || sb == 0)
return {};
if (sb < 0) {
nnet_unimplemented_continue();
return {};
}
Range rangeA = rangeOp->getVarRange(i1).second;
Range rangeB = rangeOp->getVarRange(i2).second;
auto j1 = getNewVar(), j2 = getNewVar();
Range rangeJ1, rangeJ2 = rangeB;
assert(pairA.second == 1);
rangeJ1.first = rangeA.first + rangeB.first * sb;
rangeJ1.second = rangeA.second + (rangeB.second - 1) * sb;
// build the phi/psi for index transformation
// phi: j_x=(i_x...), psi: i_x=(j_x...)
auto ret = optional<Replace>();
ret.emplace();
ret->iteratorType = IterationType::Loop;
ret->newIters = {j1, j2};
ret->oldIters = {i1, i2};
ret->newVarRanges = {{j1, rangeJ1}, {j2, rangeJ2}};
ret->phis = {sa * i1 + sb * i2, i2};
ret->psis = {j1 - (sb / sa) * j2, j2};
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,82 @@
#include "nnet/Pass/Rule3StageSplit.h"
#include "nnet/permutation.h"
namespace nnet {
void Rule3StageSplit::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
vector<vector<Iterator>> splitSchemes = getSplitSummationIters(cur);
for (const auto &varSplit : splitSchemes) {
bool isSplittable = false;
for (const auto &splitted : varSplit)
if (cur->hasSumVar(splitted))
isSplittable = true;
assert(isSplittable);
const vector<VarRangePair> loopVars = cur->getLoopVarRanges(),
sumVars = cur->getSumVarRanges();
// move iterators from Sigma to Loop
vector<VarRangePair> innerLoopVars, innerSumVars, outerSumVars;
VecExpr indexForInner;
for (const auto &kv : sumVars) {
bool isSplitted = false;
for (const auto &iter : varSplit)
if (iter == kv.first->getName())
isSplitted = true;
if (isSplitted) {
innerLoopVars.emplace_back(kv);
outerSumVars.emplace_back(kv);
} else
innerSumVars.emplace_back(kv);
}
innerLoopVars.insert(innerLoopVars.end(), loopVars.begin(),
loopVars.end());
for (const auto &[var, _] : innerLoopVars)
indexForInner.emplace_back(var);
// if no sum iterator, the stage is redundant
assert(!innerSumVars.empty());
auto inner =
makeRangeOperator(innerLoopVars, innerSumVars, cur->getSummand());
auto subscriptedInner = make_ref<SubscriptNode>(inner, indexForInner);
auto outer = makeRangeOperator(cur->getLoopVarRanges(), outerSumVars,
subscriptedInner);
outer->setPaddings(cur->getPaddings());
// next searching step
string msg = "Separate sum iters: " + serializeVec(varSplit);
nextStep(origin, depth, rCur, outer, msg);
}
}
vector<vector<Iterator>>
Rule3StageSplit::getSplitSummationIters(RangeOp rangeOp) {
// set<string> varSplit = {"r", "s", "i3", "i13"};
vector<vector<Iterator>> ret;
// Rule-based Hint
// vector<vector<Iterator>> heuristics = {{"r", "s"}, {"i3", "i13"}};
// for (const auto &iterSet : heuristics) {
// bool notExist = false;
// for (const auto &iter : iterSet)
// if (!rangeOp->hasSumVar(iter))
// notExist = true;
// if (!notExist)
// ret.emplace_back(iterSet);
// }
// if (!rulesOverall.empty())
// return ret;
vector<Iterator> sumIters;
for (const auto &[iter, range] : rangeOp->getSumVarRanges())
sumIters.emplace_back(iter);
if (sumIters.size() <= 1)
return ret;
SubsetGenerator gen(sumIters);
do {
ret.emplace_back(gen.get());
} while (gen.next());
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,91 @@
#include "nnet/Pass/Rule4StageMerging.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
void Rule4StageMerging::transform(Formula &origin, int depth, Expr &rCur) {
success = rule4StageMerging(origin, depth, rCur, mergeStageWithCalc);
}
bool Rule4StageMerging::isSuccessful() { return success; }
void Rule4StageMerging::setMergeStageWithCalc(bool value) {
mergeStageWithCalc = value;
}
bool Rule4StageMerging::rule4StageMerging(Formula &origin, int depth,
Expr &rCur, bool mergeStageWithCalc) {
auto rangeOp0 = as<RangeOpNode>(rCur);
const Subscript &sub0 = as<SubscriptNode>(rangeOp0->getSummand());
if (!sub0)
return false;
const auto &rangeOp1 = as<RangeOpNode>(sub0->getObject());
if (!rangeOp1)
return false;
const auto &sub1 = as<SubscriptNode>(rangeOp1->getSummand());
if (!sub1)
return false;
// merge stage with calculation only when mergeStageWithCalc=true
if (!mergeStageWithCalc && !rangeOp1->getSumVarRanges().empty())
return false;
// Only propogate paddings in perfect nested dimension
if (rangeOp1->hasPaddings()) {
auto oldTensor = as<TensorNode>(sub1->getObject());
if (!oldTensor) {
nnet_unimplemented_continue();
return 0;
}
}
// repalce variables: iters of rangeOp1 repalced by indexes of sub0
map<string, pair<Expr, Expr>> varMapping;
assert(sub0->getDims() == rangeOp1->getLoopVarRanges().size());
for (size_t i = 0; i < sub0->getDims(); ++i) {
varMapping[rangeOp1->getLoopVar(i)->getName()] =
pair(rangeOp1->getLoopVar(i), sub0->getIndex(i));
}
ReplaceVariable replaceVariable{varMapping};
auto merged = make_ref<RangeOpNode>(*rangeOp0);
merged->setSummand(replaceVariable(sub1));
// a naive approach to propogate paddings
if (rangeOp1->hasPaddings()) {
auto oldTensor = as<TensorNode>(sub1->getObject());
auto newTensor = make_ref<TensorNode>(*oldTensor);
for (int i = 0; i < rangeOp1->getNumOutputDims(); ++i) {
if (rangeOp1->getPaddings(i) == 0)
continue;
auto loopVar = rangeOp1->getLoopVar(i);
// FIXME: in fact this var should not appear in other index as well,
// which may result in OOB
bool findSingleVarAsIndex = false;
for (size_t subIndexID = 0; subIndexID < sub1->getDims();
++subIndexID) {
auto index = sub1->getIndex(subIndexID);
if (auto indexVar = as<VarNode>(index);
indexVar && (indexVar->equal(loopVar))) {
newTensor->setPadding(subIndexID,
newTensor->getPadding(subIndexID) +
rangeOp1->getPaddings(i));
findSingleVarAsIndex = true;
}
}
if (!findSingleVarAsIndex) {
nnet_unimplemented_continue();
return false;
}
}
merged = as<RangeOpNode>(
ReplaceNodeMutator().replace(merged, oldTensor.get(), newTensor));
assert(merged != nullptr);
}
// Merge inner stage sums
if (!rangeOp1->getSumVarRanges().empty())
merged->setSumIterator(rangeOp1->getSumVarRanges());
// next searching step
// if mergeStageWithCalc, depth counts for invocation in rule-based search
nextStep(origin, (mergeStageWithCalc) ? depth : depth - 1, rCur, merged);
return true;
}
} // namespace nnet

View File

@ -0,0 +1,72 @@
#include "nnet/Pass/Rule5RangeRelaxation.h"
#include "nnet/Visitor/RangeRelaxFunctor.h"
namespace nnet {
void Rule5RangeRelaxation::transform(Formula &origin, int depth, Expr &rCur) {
rule5RangeRelaxation(origin, depth, rCur);
}
Expr Rule5RangeRelaxation::rule5RangeRelaxation(Formula &origin, int depth,
Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
if (cur->hasPaddings()) {
// string msg = "====== END rule5RangeRelaxation: Paddings exist \n";
// dbg(msg);
return nullptr;
}
// Infer meaningful calculation range
RangeRelaxFunctor rangeRexlaxtionFunctor{cur};
RangeMap rangeMap = rangeRexlaxtionFunctor(cur);
auto relaxedCur = make_ref<RangeOpNode>(*cur);
bool isRelaxed = false;
vector<int> paddings;
// check whether narrow the calculation range
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
const auto &[iter, iterRange] =
cur->getVarRange(IterationType::Loop, i);
if (auto it = rangeMap.find(iter); it != rangeMap.end()) {
// intersection of validRange and iterRange is necessary computation
// TODO: it is redundant with RangeRelaxFunctor::intersectRangeMaps.
// An independent Range class might be necessary.
const Range &validRange = it->second;
Range relaxedRange{max(iterRange.first, validRange.first),
min(iterRange.second, validRange.second)};
if (relaxedRange != iterRange) {
isRelaxed = true;
relaxedCur->setVarRange(IterationType::Loop, i,
{iter, relaxedRange});
paddings.emplace_back(
max(relaxedRange.first - iterRange.first,
iterRange.second - relaxedRange.second));
} else
paddings.emplace_back(0);
} else
paddings.emplace_back(0);
}
relaxedCur->setPaddings(paddings);
if (!isRelaxed) {
// string msg = "====== END rule5RangeRelaxation: Relaxation not
// found\n"; dbg(msg);
return nullptr;
}
// next searching step
string detailedMsg;
for (size_t i = 0; i < cur->getLoopVarRanges().size(); ++i) {
const auto &[v, a] = cur->getVarRange(IterationType::Loop, i);
const auto &[_, b] = relaxedCur->getVarRange(IterationType::Loop, i);
if (a != b) {
detailedMsg += v->getName();
detailedMsg +=
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
detailedMsg +=
" (" + to_string(b.first) + "," + to_string(b.second) + "),";
}
}
nextStep(origin, depth, rCur, relaxedCur, detailedMsg);
return relaxedCur;
}
} // namespace nnet

View File

@ -0,0 +1,57 @@
#include "nnet/Pass/Rule6KenerlMatching.h"
#include "nnet/Visitor/InputVisitor.h"
#include "nnet/Visitor/PatternMatcher.h"
namespace nnet {
void Rule6KenerlMatching::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
// Build wrapper stages for enforce axis starts from 0
PatternMatcher patternMatcher(derivator, cur);
cur = patternMatcher.getOffsetCur();
// Match matchable routines
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
auto targetOp = idToRoutineType(i);
// During guided search, only check the target OP
if (derivator.getTargetOp() != RoutineType::NoneType &&
derivator.getTargetOp() != targetOp)
continue;
auto replaces =
patternMatcher.matchWithPattern(cur, getPattern(targetOp));
replaces = patternMatcher.applyWrapper(replaces);
for (auto newCur : replaces)
nextStep(origin, depth, rCur, newCur);
}
{ // Match element-wise OP
auto replaces = matchElementWise(cur);
if (!replaces.empty())
dbg(rCur);
for (auto newCur : replaces)
nextStep(origin, depth, rCur, newCur);
}
}
VecExpr Rule6KenerlMatching::matchElementWise(const RangeOp &rangeOp) {
// If the stage is compute bound, then do not convert it.
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
int64_t inputSize = rangeOp->getInputSize(rangeOp);
if (double(flops) / (inputSize + outputSize) > 3)
return {};
vector<int> newShape;
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
if (range.first != 0) {
nnet_unimplemented_continue();
return {};
}
newShape.emplace_back(range.second - range.first);
}
const auto &inputs = InputVisitor().getInputs(rangeOp);
auto source =
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
return {newTensor};
}
} // namespace nnet

78
src/nnet/Pass/Rule7DLT.cc Normal file
View File

@ -0,0 +1,78 @@
#include "nnet/Pass/Rule7DLT.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
#include "nnet/dlt.h"
namespace nnet {
void Rule7DLT::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
if (!cur)
return;
auto op = as<BinaryOpNode>(cur->getSummand());
if (!op)
return;
auto subs = {op->getLhs(), op->getRhs()};
for (auto subExpr : subs) {
auto sub = as<SubscriptNode>(subExpr);
if (!sub)
continue;
auto tensor = as<TensorNode>(sub->getObject());
if (!tensor)
continue;
// // HACK for G2BMM
// if (tensor->getDims() != 3)
// continue;
for (const auto factor : getFactors()) {
for (int targetDim = 0; targetDim < tensor->getDims();
++targetDim) {
if (tensor->getShape(targetDim) % factor)
continue;
// Debug hint for G2BMM
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
if (tensor->getShape(targetDim) != 10000)
continue;
assert(targetDim == 1);
}
DLT dlt;
dlt.split(targetDim, factor);
vector<int> newOrder(tensor->getDims() + 1);
for (int i = 0; i < tensor->getDims() + 1; ++i)
newOrder[i] = i;
newOrder[targetDim]++;
newOrder[targetDim + 1]--;
dlt.reorder(newOrder);
dlt.merge(targetDim, targetDim + 1);
if (auto opt = dlt.apply(cur, sub, newTensorName())) {
Expr newSummand = ReplaceNodeMutator().replace(
cur->getSummand(), sub.get(), *opt);
auto newCur = buildDLTSingleRangeOp(cur, newSummand);
// next searching step
string msg = "====== END rule7DLT\n";
dbg(msg);
nextStep(origin, depth, rCur, newCur);
}
}
}
}
}
Expr Rule7DLT::buildDLTSingleRangeOp(const RangeOp &original,
const Expr &newSummand) {
auto rangeOp = make_ref<RangeOpNode>(*original);
rangeOp->setSummand(newSummand);
return rangeOp;
}
vector<int> Rule7DLT::getFactors() {
if (derivator.getPassMode() == Derivator::PassMode::Debug) {
return {4};
} else if (derivator.getPassMode() == Derivator::PassMode::Full) {
return {3, 4};
} else {
nnet_unimplemented_halt();
return {};
}
}
} // namespace nnet

View File

@ -0,0 +1,317 @@
#include "nnet/Pass/Rule8GuidedDLT.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
namespace nnet {
static int bitCount(unsigned int n) {
int count = 0;
while (n != 0) {
n = n & (n - 1);
count++;
}
return count;
}
static int bitPosition(unsigned int n) {
assert(bitCount(n) == 1);
int ret = 0;
for (n >>= 1; n; n >>= 1)
++ret;
return ret;
}
void Rule8GuidedDLT::transform(Formula &origin, int depth, Expr &rCur) {
guidedDLT(origin, depth, rCur);
}
VecExpr Rule8GuidedDLT::guidedDLT(Formula &origin, int depth, Expr &rCur,
bool debug) {
string detailedMsg;
VecExpr ret;
auto cur = as<RangeOpNode>(rCur);
// check cur satisfies T1[A]*T2[B]
if (!statisfyGuidedDLT(cur))
return ret;
IteratorTable exprIT;
if (!exprIT.analyzeExpr(cur))
return ret;
exprIT.buildTableWithDefaultMap();
bool setTargetOpHere = false;
for (int i = 0; i < MatchableRoutineTypeCnt; ++i) {
// if not correctly unset this variable
assert(setTargetOpHere == false);
// If the guide direction is set
if (derivator.getTargetOp() != RoutineType::NoneType &&
idToRoutineType(i) != derivator.getTargetOp())
continue;
// Warning: no continue befor unset the targetOp
if (derivator.getTargetOp() == RoutineType::NoneType) {
setTargetOpHere = true;
derivator.setTargetOp(idToRoutineType(i));
}
const Pattern &pattern = getPattern(derivator.getTargetOp());
auto mismatches = exprIT.matchPatternIT(pattern);
// Pruning less possible results
// std::cout << "mismatches= " << mismatches.size()
// << "; setTargetOpHere: " << setTargetOpHere << "; ";
// std::cout << "TargetOp = " <<
// static_cast<int>(derivator.getTargetOp())
// << "; mismatches : ";
// for (const auto i : mismatches)
// std::cout << static_cast<int>(i.type) << " ";
// std::cout << endl;
if (mismatches.size() == 0) {
derivator.setSearchState(2);
nextStep(origin, depth, rCur, rCur);
derivator.setSearchState(1);
}
if (mismatches.size() > 0 && mismatches.size() <= 2) {
for (const auto &mismatch : mismatches) {
Expr newCur;
if (mismatch.type == MismatchType::MoreVar) {
newCur = guidedDLTMoreVar2(cur, mismatch, exprIT, pattern);
detailedMsg += "guidedDLTMoreVar2 ";
} else if (mismatch.type == MismatchType::DLMismatch ||
mismatch.type == MismatchType::OutputDLMismatch) {
if (mismatches.size() > 1) {
nnet_unimplemented_continue();
break;
}
newCur =
guidedDLTDLMismatch(cur, mismatch, exprIT, pattern);
detailedMsg += "guidedDLTDLMismatch ";
}
// std::cout << "newCur= "
// << ((newCur == nullptr) ? "Nullptr"
// : newCur->toReadable())
// << endl;
if (!newCur)
continue;
if (debug)
ret.emplace_back(newCur);
// next searching step
detailedMsg = "Toward " +
getPatternName(derivator.getTargetOp()) + ". " +
detailedMsg;
nextStep(origin, depth, rCur, newCur, detailedMsg);
}
}
// Unset targetOp
if (setTargetOpHere) {
derivator.setTargetOp(RoutineType::NoneType);
setTargetOpHere = false;
}
}
return ret;
}
Expr Rule8GuidedDLT::guidedDLTDLMismatch(
const RangeOp &cur, const Mismatch &mismatch,
[[maybe_unused]] const IteratorTable &exprIT, const Pattern &pattern) {
assert(mismatch.type == MismatchType::DLMismatch ||
mismatch.type == MismatchType::OutputDLMismatch);
// Currently only deal with ouput DLT
if (mismatch.bitmap != pattern.getNumInputs()) {
nnet_unimplemented_continue();
return nullptr;
}
vector<VarRangePair> newVarRanges;
for (const auto &[var, _] : pattern.getRangeOp()->getLoopVarRanges()) {
const auto &iterInExpr = mismatch.mappingIter_r.at(var);
newVarRanges.emplace_back(cur->getVarRange(iterInExpr));
}
auto inner = make_ref<RangeOpNode>(*cur);
inner->setLoopIterator(newVarRanges);
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, {});
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
return outer;
}
bool Rule8GuidedDLT::statisfyGuidedDLT(RangeOp cur) const {
auto mul = as<BinaryOpNode>(cur->getSummand());
if (!mul)
return false;
if (mul->getOpType() != OpType::Mul)
return false;
return as<SubscriptNode>(mul->getLhs()) && as<SubscriptNode>(mul->getRhs());
}
Expr Rule8GuidedDLT::guidedDLTMoreVar2(const RangeOp &cur,
const Mismatch &mismatch,
const IteratorTable &exprIT,
const Pattern &pattern) {
int bitmap = mismatch.bitmap;
const auto &mergedItersDefaultOrder = exprIT.getPosTable(bitmap);
// Assure vars only appear in one input tensor
int bitmapOfInputs = bitmap & ((1 << exprIT.getNumInputs()) - 1);
if (bitCount(bitmapOfInputs) > 1)
return nullptr;
if (pattern.getPosTable(bitmap).size() != 1) {
nnet_unimplemented_continue();
return nullptr;
}
if (mergedItersDefaultOrder.size() < 1)
return nullptr;
int tensorID = bitPosition(bitmapOfInputs);
if (!checkElementsHaveOnlyOneAccessIteratorSet(exprIT, tensorID))
return nullptr;
vector<Var> oldVars; // i_1, ...
vector<Var> newVars; // j_1, ...
VecExpr psis; // i_1=\psi_1(j_1, ...)
VecExpr phis; // j_1=\phi_1(i_1, ...), not necessary for Sum iter
vector<VarRangePair> newVarRanges;
auto originalTensor = exprIT.getTensor(tensorID);
auto originalSub = exprIT.getSubscript(tensorID);
vector<bool> mergedDims(originalTensor->getDims());
// Heuristic: merge iters according to their appearance positions
std::multimap<int, Var> sortedMergedIters;
for (const auto &iter : mergedItersDefaultOrder) {
vector<int> dims = exprIT.getIterDimInTensor(tensorID, iter);
assert(dims.size() == 1);
sortedMergedIters.emplace(dims[0], iter);
}
vector<Var> mergedIters; // decides the order of fused dims
for (const auto &[_, v] : sortedMergedIters)
mergedIters.emplace_back(v);
// Add the merged iterators
const auto newVar = getNewVar();
newVars.emplace_back(newVar);
int newRange = 1;
for (const auto &iter : mergedIters) {
oldVars.emplace_back(iter);
auto range = cur->getRange(iter);
newRange *= (range.second - range.first);
// if (range.first == 0)
// nnet_unimplemented_halt();
}
newVarRanges.emplace_back(newVar, Range{0, newRange});
// Add psis for each old iterator
int remainingRange = newRange;
Expr phi = nullptr;
for (const auto &iter : mergedIters) {
auto oldVar = iter;
auto range = cur->getRange(iter);
int len = (range.second - range.first);
remainingRange /= len;
Expr psi = newVar;
if (remainingRange > 1)
psi = psi / remainingRange;
if (newRange > remainingRange * len)
psi = psi % len;
int start = cur->getRange(iter).first;
if (start != 0)
psi = psi + start;
psis.emplace_back(psi);
phi = phi + remainingRange * (oldVar - start);
}
Replace replace{.iteratorType = IterationType::Loop,
.oldIters = oldVars,
.newIters = newVars,
.phis = VecExpr{phi},
.psis = psis,
.newVarRanges = newVarRanges};
// HACK: decide the rebuild data shape order
// TODO: get a partial iter mapping and permutate them?
vector<Var> tensorDimAxes{newVars};
vector<int> newShape;
for (const auto &[var, range] : newVarRanges)
newShape.emplace_back(range.second - range.first);
for (int row = 0; row < exprIT.getNumRows(); ++row) {
// Deal with other dimensions of the current tensor
if (row == bitmap || ((row & (1 << tensorID)) == 0))
continue;
using StrideIter = tuple<int, int, Iterator>;
vector<StrideIter> strideIters;
for (size_t i = 0; i < exprIT.getPosTable(row).size(); ++i) {
const auto &iter = exprIT.getPosTable(row)[i];
const Range range = cur->getRange(iter);
const int len = range.second - range.first;
// HACK Sort according to original stride. (keep original order)
strideIters.emplace_back(-exprIT.getStridesInTensor(iter, tensorID),
len, iter);
// // HACK for conv
// if (iter == "n")
// strideIters.emplace_back(2, len, iter);
// else if (iter == "c")
// strideIters.emplace_back(1, len, iter);
// else
// strideIters.emplace_back(0, len, iter);
}
// HACK: Assure the order of iterators
std::sort(strideIters.begin(), strideIters.end(),
ref_value_less<StrideIter>);
for (const auto &[_, len, oldIter] : strideIters) {
const auto &oldVar = oldIter;
tensorDimAxes.emplace_back(oldVar);
newShape.emplace_back(len);
}
}
// build DLT source
const auto sourceExpr =
buildGuidedDLTSource(originalSub, replace, tensorDimAxes, newShape);
const auto sourceRoutine = make_ref<ElementWiseNode>(
sourceExpr, vector<Tensor>{originalTensor}, newShape);
// build stage connections
const auto newTensor =
makeTensor(newTensorName(), newShape, {}, sourceRoutine);
const auto &newSub = makeSubscript(
newTensor, VecExpr(tensorDimAxes.begin(), tensorDimAxes.end()));
// TODO [1124]: get variable mapping and reorder L according to it
// dbg(cur, originalSub, newSub, newVarRanges, replace.toReadable(),
// tensorDimAxes, newShape);
// Replace the entire subscript(A[xxxxx,xxx]) in the summand
Expr newSummand = ReplaceNodeMutator().replace(cur->getSummand(),
originalSub.get(), newSub);
auto inner = ReplaceKit::replaceRangeOpIterator(cur, replace, newSummand);
auto subscriptedInner =
ReplaceKit::buildSubscirptForLoopVarReplace(inner, replace);
auto outer = ReplaceKit::buildDLTOuterRangeOp(cur, subscriptedInner);
return outer;
}
bool Rule8GuidedDLT::checkElementsHaveOnlyOneAccessIteratorSet(
const IteratorTable &exprIT, int tensorID) {
const auto &strideInDim = exprIT.getStrideInDim();
for (const auto &strideForOneDim : strideInDim[tensorID]) {
vector<pair<int, int>> strideLengthPairs;
for (const auto &[iter, s] : strideForOneDim) {
const auto &range = exprIT.getRangeOp()->getRange(iter);
strideLengthPairs.emplace_back(s, range.second - range.first);
}
std::sort(strideLengthPairs.begin(), strideLengthPairs.end());
for (size_t i = 0; i < strideLengthPairs.size() - 1; ++i) {
const auto &[stride, length] = strideLengthPairs[i];
if (stride * length > strideLengthPairs[i + 1].first)
return false;
}
}
return true;
}
Expr Rule8GuidedDLT::buildGuidedDLTSource(const Subscript &originalSub,
Replace replace,
vector<Var> tensorDimAxes,
vector<int> newShape) {
Expr newSub = ReplaceKit::replaceMultipleExprs(
originalSub, replace.oldIters, replace.psis, true);
vector<VarRangePair> loopVarRangePairs;
for (size_t i = 0; i < tensorDimAxes.size(); ++i)
loopVarRangePairs.emplace_back(tensorDimAxes[i], pair(0, newShape[i]));
return makeRangeOperator(loopVarRangePairs, {}, newSub);
}
} // namespace nnet

View File

@ -0,0 +1,54 @@
#include "nnet/Pass/Rule90TwoStageElementWise.h"
#include "nnet/Visitor/InputVisitor.h"
namespace nnet {
void Rule90TwoStageElementWise::transform(Formula &origin, int depth,
Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
{ // Match element-wise OP
auto replaces = matchTwoStageElementWise(cur);
// if (!replaces.empty())
// dbg(rCur);
// dbg(replaces);
for (auto newCur : replaces)
nextStep(origin, depth, rCur, newCur);
}
}
VecExpr
Rule90TwoStageElementWise::matchTwoStageElementWise(const RangeOp &rangeOp) {
// If the stage is compute bound, then do not convert it.
int64_t flops = rangeOp->getFlops(), outputSize = rangeOp->getOutputSize();
int64_t inputSize = rangeOp->getInputSize(rangeOp);
if (double(flops) / (inputSize + outputSize) > 3)
return {};
auto outerSub = as<SubscriptNode>(rangeOp->getSummand());
if (!outerSub)
return {};
auto innerRangeOp = as<RangeOpNode>(outerSub->getObject());
if (!innerRangeOp)
return {};
auto innerSub = as<SubscriptNode>(innerRangeOp->getSummand());
if (!innerSub)
return {};
auto innerTensor = as<TensorNode>(innerSub->getObject());
if (!innerTensor)
return {};
vector<int> newShape;
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
if (range.first != 0) {
nnet_unimplemented_continue();
return {};
}
newShape.emplace_back(range.second - range.first);
}
const auto &inputs = InputVisitor().getInputs(rangeOp);
auto source =
make_ref<ElementWiseNode>(rangeOp, inputs, rangeOp->getOutputShape());
auto newTensor = makeTensor(newTensorName(), newShape, {}, source);
return {newTensor};
}
} // namespace nnet

View File

@ -0,0 +1,11 @@
#include "nnet/Pass/Rule91MergeStagesWithSum.h"
#include "nnet/Pass/Rule4StageMerging.h"
namespace nnet {
void Rule91MergeStagesWithSum::transform(Formula &origin, int depth,
Expr &rCur) {
Rule4StageMerging(derivator).rule4StageMerging(origin, depth, rCur, true);
}
} // namespace nnet

View File

@ -0,0 +1,45 @@
#include "nnet/Pass/Rule9RangeMagnify.h"
#include "nnet/Visitor/RangeMagnifyVisitor.h"
namespace nnet {
void Rule9RangeMagnify::transform(Formula &origin, int depth, Expr &rCur) {
auto cur = as<RangeOpNode>(rCur);
if (cur->hasPaddings()) {
// string msg = "====== END rule9RangeMagnify: Paddings exist \n";
// dbg(msg);
return;
}
// HACK for conv5x5
vector<VarRangePair> newSumVarRanges;
for (const auto &[var, range] : cur->getSumVarRanges()) {
if (range.first == 0 && range.second == 5) {
newSumVarRanges.emplace_back(
var, Range{range.first, (range.second + 2) / 3 * 3});
} else
newSumVarRanges.emplace_back(var, range);
}
if (newSumVarRanges.empty())
return;
auto magnifiedCur = RangeMagnifyVisitor().magnify(cur, newSumVarRanges);
if (!magnifiedCur)
return;
// next searching step
string msg = "relax iterating ranges ";
for (size_t i = 0; i < cur->getSumVarRanges().size(); ++i) {
const auto &[v1, a] = cur->getVarRange(IterationType::Sum, i);
const auto &[v2, b] = magnifiedCur->getVarRange(IterationType::Sum, i);
assert(v1->getName() == v2->getName());
if (a != b) {
msg += v1->getName();
msg +=
" (" + to_string(a.first) + "," + to_string(a.second) + ") to";
msg += " (" + to_string(b.first) + "," + to_string(b.second) + "),";
}
}
nextStep(origin, depth, rCur, magnifiedCur, msg);
return;
}
} // namespace nnet

View File

@ -0,0 +1,165 @@
#include "nnet/Visitor/AsTVMVisitor.h"
namespace nnet {
std::string AsTVMVisitor::visit_(const Constant &c) {
return std::to_string(c->getValue());
}
std::string AsTVMVisitor::visit_(const BinaryOp &c) {
switch (c->getOpType()) {
case OpType::Add:
return "(" + dispatch(c->getLhs()) + " + " + dispatch(c->getRhs()) +
")";
case OpType::Sub:
return "(" + dispatch(c->getLhs()) + " - " + dispatch(c->getRhs()) +
")";
case OpType::Mul:
return "(" + dispatch(c->getLhs()) + " * " + dispatch(c->getRhs()) +
")";
case OpType::Div:
return "(" + dispatch(c->getLhs()) + " // " + dispatch(c->getRhs()) +
")";
case OpType::Mod:
return "(" + dispatch(c->getLhs()) + " % " + dispatch(c->getRhs()) +
")";
default:
assert(false);
}
}
std::string AsTVMVisitor::visit_(const Func &c) {
switch (c->getFuncType()) {
case FuncType::Relu:
// TODO: Deduce the dtype
return "te.max(" + dispatch(c->getObject()) +
", tvm.tir.const(0, 'float32'))";
case FuncType::Tanh:
return "te.tanh(" + dispatch(c->getObject()) + ")";
default:
assert(false);
}
}
std::string AsTVMVisitor::visit_(const RangeOp &c) {
auto outerStage = curStage;
curStage = nStage++;
std::string stmt;
std::string stageName = "s" + std::to_string(curStage);
std::vector<std::string> reduceVars;
for (auto &&[var, range] : c->getSumVarRanges()) {
std::string varName = stageName + "_" + var->getName();
stmt += varName + " = " + "te.reduce_axis((" +
std::to_string(range.first) + ", " +
std::to_string(range.second) + "), name=\"" + varName + "\")\n";
reduceVars.emplace_back(varName);
pythonVars.emplace_back(varName);
}
std::vector<int> shape;
stmt += stageName + " = te.compute((";
for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
auto &&[var, range] = c->getLoopVarRanges()[i];
std::string varName = stageName + "_" + var->getName();
offset[varName] = -range.first + c->getPaddings(i);
auto len = range.second - range.first + 2 * c->getPaddings(i);
stmt += std::to_string(len) + ", ";
shape.emplace_back(len);
}
stmt += "), lambda ";
bool first = true;
for (auto &&[var, range] : c->getLoopVarRanges()) {
std::string varName = stageName + "_" + var->getName();
stmt += (first ? "" : ", ") + varName;
first = false;
}
std::string summand = dispatch(c->getSummand());
if (!reduceVars.empty()) {
summand = "te.sum(" + summand + ", axis=(";
for (auto &&var : reduceVars) {
summand += var + ", ";
}
summand += "))";
}
if (c->hasPaddings()) {
std::string guard = "tir.if_then_else(tir.all(";
bool first = true;
for (size_t i = 0, n = c->getLoopVarRanges().size(); i < n; i++) {
auto &&[var, range] = c->getLoopVarRanges()[i];
std::string varName = stageName + "_" + var->getName();
if (auto pad = c->getPaddings(i); pad > 0) {
guard += (first ? "" : ", ") + varName +
" >= " + std::to_string(range.first) + ", " + varName +
" < " + std::to_string(range.second);
first = false;
}
}
// TODO: Deduce the dtype
guard += "), " + summand + ", tvm.tir.const(0.0, \"float32\"))";
summand = guard;
}
stmt += ": " + summand + ")";
stmts += stmt + "\n";
pythonVars.emplace_back(stageName);
output = stageName;
outputShape = std::move(shape);
curStage = outerStage;
return stageName;
}
std::string AsTVMVisitor::visit_(const Subscript &c) {
std::string str = dispatch(c->getObject()) + "[";
for (size_t i = 0, n = c->getIndex().size(); i < n; i++) {
const auto &idx = c->getIndex()[i];
str += (i == 0 ? "" : ", ") + dispatch(idx);
if (c->getObject()->getType() == NodeType::RangeOpNodeType) {
auto rangeOp = as<RangeOpNode>(c->getObject());
str += " - " +
std::to_string(rangeOp->getLoopVarRanges()[i].second.first -
rangeOp->getPaddings(i));
}
}
str += "]";
return str;
}
std::string AsTVMVisitor::visit_(const Var &c) {
std::string stageName = "s" + std::to_string(curStage);
std::string varName = stageName + "_" + c->getName();
if (offset.count(varName)) {
return "(" + varName + " - " + std::to_string(offset.at(varName)) + ")";
} else {
return varName;
}
}
std::string AsTVMVisitor::visit_(const Tensor &c) {
pythonVars.emplace_back(c->getName());
inputs.emplace_back(c->getName());
inputShapes.emplace_back(c->getShape());
std::string stmt = c->getName() + " = te.placeholder((";
for (auto &&dim : c->getShape()) {
stmt += std::to_string(dim) + ", ";
}
stmt += "), name='" + c->getName() + "')";
stmts += stmt + "\n";
return c->getName();
}
std::string AsTVMVisitor::getStmts() const {
std::string ret;
// Workaround because closure capturing does not work in an `exec`
// https://stackoverflow.com/questions/2749655/why-are-closures-broken-within-exec
ret += "global ";
bool first = true;
for (auto &&var : pythonVars) {
ret += (first ? "" : ", ") + var;
first = false;
}
ret += "\n";
ret += stmts;
ret += "ret = [" + output;
for (auto &&input : inputs) {
ret += ", " + input;
}
ret += "]\n";
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,35 @@
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
void CheckOOBVisitor::visit_(const Subscript &c) {
const auto &objectRanges = c->getObjectRangesWithPaddings();
for (size_t dim = 0; dim < c->getDims(); ++dim) {
SimplifyExprVisitor simplifier;
auto optional = simplifier.getExprRange(c->getIndex(dim), rangeOp);
if (!optional.has_value())
continue;
const Range &exprRange = *optional;
if (exprRange.first < objectRanges[dim].first ||
exprRange.second > objectRanges[dim].second) {
// dbg("OOB detected!", c, dim, exprRange, objectRanges[dim]);
// std::cout << "OOB detected! " << c->toReadable() << ", dim=" <<
// dim
// << ", Range=(" << exprRange.first << ", "
// << exprRange.second << "), objRange=("
// << objectRanges[dim].first << ", "
// << objectRanges[dim].second << ")." << std::endl;
detect = true;
}
}
}
bool CheckOOBVisitor::checkRangeOp(const RangeOp &_rangeOp) {
detect = false;
rangeOp = _rangeOp;
dispatch(rangeOp);
return detect;
}
} // namespace nnet

View File

@ -0,0 +1,9 @@
#include "nnet/Visitor/CloneMutator.h"
namespace nnet {
Expr CloneMutator::visit_(const Constant &c) { return c; }
Expr CloneMutator::visit_(const Var &c) { return c; }
Expr CloneMutator::visit_(const Tensor &c) { return c; }
} // namespace nnet

View File

@ -0,0 +1,34 @@
#include "nnet/Visitor/CompareMultiFormulasVisitor.h"
namespace nnet {
bool CompareMultiFormulasVisitor::compare(const VecExpr &roots) {
if (roots.empty())
return false;
vector<RangeOp> rangeOps;
for (const auto &root : roots) {
if (auto rangeOp = as<RangeOpNode>(root))
rangeOps.emplace_back(rangeOp);
else
return false;
}
const auto pattern = rangeOps[0];
for (auto rangeOp : rangeOps) {
if (pattern->getNumOutputDims() != rangeOp->getNumOutputDims()) {
return false;
}
for (int i = 0; i < pattern->getNumOutputDims(); ++i)
if (pattern->getVarRange(0, i).second !=
rangeOp->getVarRange(0, i).second) {
return false;
}
for (size_t i = 0; i < pattern->getSumVarRanges().size(); ++i)
if (pattern->getVarRange(1, i).second !=
rangeOp->getVarRange(1, i).second) {
return false;
}
}
return true;
}
} // namespace nnet

View File

@ -0,0 +1,38 @@
#include "nnet/Visitor/CountRoutineVisitor.h"
namespace nnet {
void CountRoutineVisitor::visit_(const Tensor &c) {
if (auto routine = c->getSource(); routine) {
cnts[routineTypeToId(routine->getType())]++;
}
ExprTreeVisitor::visit_(c);
}
vector<int> CountRoutineVisitor::count(const Expr &root) {
cnts = vector<int>(RoutineTypeCnt, 0);
dispatch(root);
return cnts;
}
bool CountRoutineVisitor::match(const Expr &root, int nMatmul, int nConv,
int nElement, int nSg2bmm,
int nLongformerGBMM) {
auto opCount = count(root);
bool ret = true;
if (opCount[routineTypeToId(RoutineType::MatmulNodeType)] != nMatmul)
ret = false;
if (opCount[routineTypeToId(RoutineType::ConvNodeType)] != nConv)
ret = false;
if (opCount[routineTypeToId(RoutineType::ElementWiseNodeType)] != nElement)
ret = false;
if (opCount.at(routineTypeToId(RoutineType::G2bmmNodeType)) != nSg2bmm)
ret = false;
if (!ret) {
auto target =
vector<int>{nMatmul, nConv, nSg2bmm, nLongformerGBMM, nElement};
}
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,58 @@
#include "nnet/Visitor/FullPrinterVisitor.h"
namespace nnet {
void FullPrinterVisitor::visit_(const Tensor &c) {
q.emplace_back(c->getName(), c->getSource(), c);
}
string FullPrinterVisitor::print(const Expr &root) {
q.clear();
std::ostringstream oss;
dispatch(root);
oss << "==> ROOT\n" << root->toReadable() << "\n";
for (size_t i = 0; i < q.size(); ++i) {
const auto &[name, routine, tensor] = q[i];
oss << "==> " << name << " : ";
if (routine) {
oss << routine->toReadable() << "\n";
if (routine->getExpr()) {
oss << routine->getExpr()->toReadable() << "\n";
} else
oss << "[INFO] Source is nullptr \n";
if (!routine->getInputs().empty()) {
for (const auto &tensor : routine->getInputs())
q.emplace_back(tensor->getName(), tensor->getSource(),
tensor);
} else if (routine->getExpr())
dispatch(routine->getExpr());
} else
oss << "Input Tensor " << tensor->toOutputShape() << "\n";
}
return oss.str();
}
const vector<tuple<string, Routine, Tensor>> &
FullPrinterVisitor::traverse(const Expr &root) {
q.clear();
dispatch(root);
for (size_t i = 0; i < q.size(); ++i) {
const auto &[name, routine, tensor] = q[i];
if (routine) {
// Matmul after DLT do not modify expression, so inputs has a higher
// priority. Some OPs such as DLT have not implement source. Then
// use inputs
if (!routine->getInputs().empty()) {
for (const auto &tensor : routine->getInputs())
dispatch(tensor);
} else if (routine->getExpr()) {
dispatch(routine->getExpr());
} else {
assert(false);
}
}
}
return q;
}
} // namespace nnet

View File

@ -0,0 +1,9 @@
#include "nnet/Visitor/GetTensorsVisitor.h"
namespace nnet {
void GetTensorsVisitor::visit_(const Tensor &c) {
tensors.try_emplace(c->getName(), c);
}
} // namespace nnet

View File

@ -0,0 +1,7 @@
#include "nnet/Visitor/InputVisitor.h"
namespace nnet {
void InputVisitor::visit_(const Tensor &c) { inputs.emplace_back(c); }
} // namespace nnet

View File

@ -0,0 +1,174 @@
#include "nnet/Visitor/Interpreter.h"
#include "nnet/Visitor/GetTensorsVisitor.h"
#include "nnet/expr.h"
namespace nnet {
using ttype = Interpreter::ttype; // Test data type
using rtype = Interpreter::rtype; // Return data type
using Position = Interpreter::Position;
using Inputs = Interpreter::Inputs;
using Iteration = Interpreter::Iteration;
Inputs Interpreter::genInputStartingFromZero(const RangeOp &range) {
Inputs inputs;
GetTensorsVisitor getTensorsVisitor;
auto tensors = getTensorsVisitor.get(range);
for (const auto &[name, tensor] : tensors) {
auto data = make_ref<vector<int>>(tensor->getSize());
for (ssize_t i = 0; i < tensor->getSize(); i++) {
data->at(i) = i;
}
inputs.emplace(name, data);
}
return inputs;
}
Interpreter::Interpreter(RangeOp range, int _verbose)
: Interpreter(genInputStartingFromZero(range), _verbose){};
rtype Interpreter::visit_(const Constant &c) { return c->getValue(); }
rtype Interpreter::visit_(const BinaryOp &c) {
rtype valueL = dispatch(c->getLhs()), valueR = dispatch(c->getRhs());
switch (c->getOpType()) {
case OpType::Add:
return valueL + valueR;
case OpType::Mul:
return valueL * valueR;
case OpType::Div:
nnet_assert(valueR > 0, "Negative divisor is ill-defeind");
return valueL / valueR;
case OpType::Mod:
nnet_assert(valueR > 0, "Negative divisor is ill-defeind");
return valueL % valueR;
case OpType::Sub:
return valueL - valueR;
default:
nnet_unimplemented_halt();
return -1;
}
}
rtype Interpreter::visit_(const RangeOp &c) {
rtype ret = 0;
iterations.emplace_back();
// loop
auto loopRanges = c->getLoopVarRanges();
assert(positions.back().size() == loopRanges.size());
auto paddings = c->getPaddings();
for (int i = 0, iEnd = loopRanges.size(); i < iEnd; i++) {
int left = loopRanges[i].second.first;
int right = loopRanges[i].second.second;
int padding = paddings[i];
int element = positions.back()[i];
if (0 < padding) {
nnet_assert(left - padding <= element, "Out of range");
nnet_assert(element < right + padding, "Out of range");
if (left <= element && element < right) {
iterations.back()[loopRanges[i].first] = positions.back()[i];
} else {
iterations.pop_back();
return 0;
}
} else {
nnet_assert(left <= element, "Out of range");
nnet_assert(element < right, "Out of range");
iterations.back()[loopRanges[i].first] = positions.back()[i];
}
}
// sum
auto sumVarRanges = c->getSumVarRanges();
int nSumIters = sumVarRanges.size();
if (0 < nSumIters) {
vector<int> sumIterValues(nSumIters);
for (const auto &[var, range] : sumVarRanges) {
sumIterValues.emplace_back(range.first);
nnet_assert(range.first < range.second, "No empty range");
}
// Enumerate all values of sum iterator
do {
for (int i = 0; i < nSumIters; i++)
iterations.back()[sumVarRanges[i].first] = sumIterValues[i];
ret += dispatch(c->getSummand());
// Increase with carry to enumerate sum iterators
sumIterValues[nSumIters - 1]++;
for (int i = nSumIters - 1; 0 < i; i--) {
if (sumIterValues[i] == sumVarRanges[i].second.second) {
sumIterValues[i] = sumVarRanges[i].second.first;
sumIterValues[i - 1]++;
}
}
} while (sumIterValues[0] < sumVarRanges[0].second.second);
} else {
ret += dispatch(c->getSummand());
}
iterations.pop_back();
return ret;
}
rtype Interpreter::visit_(const Subscript &c) {
int ret = 0;
vector<int> idx;
auto sub = c->getIndex();
for (int i = 0, iEnd = sub.size(); i < iEnd; i++) {
idx.emplace_back(dispatch(sub[i]));
}
auto obj = c->getObject();
if (obj->getType() == NodeType::RangeOpNodeType) {
positions.emplace_back(idx);
ret = dispatch(obj);
positions.pop_back();
} else if (obj->getType() == NodeType::TensorNodeType) {
auto tensor = as<TensorNode>(obj);
const auto &data = inputs[tensor->getName()];
ret = tensor->getData(data, idx);
} else
assert(false);
return ret;
}
rtype Interpreter::visit_(const Var &c) { return iterations.back()[c]; }
rtype Interpreter::visit_(const Tensor &c) {
nnet_unimplemented_halt();
return -1;
}
vector<rtype> Interpreter::interpret(const Expr &expr,
const vector<Position> &poses) {
vector<rtype> ret;
for (const auto &pos : poses) {
positions.emplace_back(pos);
ret.emplace_back(dispatch(expr));
positions.pop_back();
}
return ret;
}
vector<rtype> Interpreter::interpretUniformSample(const RangeOp &range,
int nPoses) {
vector<Interpreter::Position> poses;
const auto &rangeShape = range->getOutputShape();
for (int i = 0; i < nPoses; ++i) {
Interpreter::Position pos(range->getNumOutputDims(), 0);
ssize_t t = range->getOutputSize() / nPoses * i;
for (int j = range->getNumOutputDims() - 1; j >= 0; --j) {
int extent = rangeShape[j];
pos[j] = t % extent;
t /= extent;
}
poses.emplace_back(pos);
}
return interpret(range, poses);
}
vector<rtype> Interpreter::interpretAllOutput(const RangeOp &range) {
return interpretUniformSample(range, range->getOutputSize());
}
} // namespace nnet

View File

@ -0,0 +1,46 @@
#include "nnet/Visitor/MatchReshapeVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
bool MatchReshapeVisitor::visit_(const RangeOp &memboundRangeOp) {
// Merge nested stages
auto rangeOp =
as<RangeOpNode>(MergeMemboundMutator({memboundRangeOp}).merge());
assert(rangeOp);
auto sub = as<SubscriptNode>(rangeOp->getSummand());
if (!sub)
return false;
auto sumRanges = rangeOp->getSumVarRanges();
for (auto const &[var, range] : sumRanges) {
if (range.second - range.first != 1)
return false;
}
const auto objectRanges = sub->getObjectRangesWithoutPaddings();
const auto indices = sub->getIndex();
Expr indexExpr;
int stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
indexExpr = indexExpr + indices.at(i) * stride;
stride *= (objectRanges.at(i).second - objectRanges.at(i).first);
}
SimplifyExprVisitor simplifyExprVisitor;
simplifyExprVisitor.simplify(indexExpr);
auto exprStrides = simplifyExprVisitor.getStrides();
auto varRanges = rangeOp->getLoopVarRanges();
stride = 1;
// compare strides of variables in RangeOP and index
for (auto i = varRanges.rbegin(); i != varRanges.rend(); ++i) {
const bool alwaysZero = i->second.first == 0 && i->second.second == 1;
if (!alwaysZero && exprStrides[i->first] != stride)
return false;
stride *= (i->second.second - i->second.first);
}
return true;
}
} // namespace nnet

View File

@ -0,0 +1,56 @@
#include "nnet/Visitor/MatchTableVisitor.h"
namespace nnet {
void MatchTableVisitor::visit_(const BinaryOp &c, const Tensor &tensor, int dim,
optional<int> stride) {
if (c->getOpType() == OpType::Add) {
dispatch(c->getLhs(), tensor, dim, stride);
dispatch(c->getRhs(), tensor, dim, stride);
} else if (c->getOpType() == OpType::Sub) {
dispatch(c->getLhs(), tensor, dim, stride);
if (stride)
*stride = -*stride;
dispatch(c->getRhs(), tensor, dim, stride);
} else if (c->getOpType() == OpType::Mul) {
const optional<int> &lStride = subexprStride[c->getLhs().get()];
const optional<int> &rStride = subexprStride[c->getRhs().get()];
optional<int> lCurStride =
(stride && rStride) ? optional(*stride * *rStride) : nullopt;
optional<int> rCurStride =
(stride && lStride) ? optional(*stride * *lStride) : nullopt;
dispatch(c->getLhs(), tensor, dim, lCurStride);
dispatch(c->getRhs(), tensor, dim, rCurStride);
} else {
hasUnsupportedOp = true;
}
}
void MatchTableVisitor::visit_(const Subscript &c, const Tensor &tensor,
[[maybe_unused]] int dim,
[[maybe_unused]] optional<int> stride) {
assert(!tensor); // Should not be set until visit a tensor
auto object = as<TensorNode>(c->getObject());
assert(object);
tensors.emplace_back(object);
int currentStride = 1;
for (int i = (int)c->getDims() - 1; i >= 0; --i) {
this->dispatch(c->getIndex(i), object, i, currentStride);
currentStride *= object->getShape(i);
}
}
void MatchTableVisitor::visit_(const Var &c, const Tensor &tensor, int dim,
optional<int> stride) {
appearance.try_emplace(c);
appearance[c].emplace_back(pair(tensor, dim));
strideTable[c].emplace_back(tensor.get(), dim, stride.value());
}
void MatchTableVisitor::visit_([[maybe_unused]] const Constant &c,
[[maybe_unused]] const Tensor &tensor,
[[maybe_unused]] int dim,
[[maybe_unused]] optional<int> stride) {
return;
}
} // namespace nnet

View File

@ -0,0 +1,100 @@
#include "nnet/Visitor/MatmulTransposeMutator.h"
namespace nnet {
VecExpr MatmulTransposeMutator::transpose(const Tensor &tensor) {
assert(tensor->getDims() == 2);
const auto matmul = as<MatmulNode>(tensor->getSource());
VecExpr ret;
for (int i = 1; i < 8; ++i) {
// Whether really transpose/swap AB. transa/b are the arguments for gemm
const int Atrans = (i & 1) > 0;
const int Btrans = (i & 2) > 0;
const int ABswap = (i & 4) > 0;
auto newShape = tensor->getShape();
auto newPaddings = tensor->getPaddings();
auto [b, m, n, k, transa, transb] = matmul->getArgs();
auto inputs = matmul->getInputs();
transa ^= Atrans;
transb ^= Btrans;
// build input transpose
if (Atrans)
inputs[0] = transposeInput(inputs[0]);
if (Btrans)
inputs[1] = transposeInput(inputs[1]);
if (ABswap) {
std::swap(inputs[0], inputs[1]);
std::swap(m, n);
std::swap(transa, transb);
std::swap(newShape[0], newShape[1]);
std::swap(newPaddings[0], newPaddings[1]);
transa ^= 1;
transb ^= 1;
}
// build new Gemm Routine and Tensor
// HACK: trivially wrap the source to generate different hash
auto _va = make_ref<VarNode>("transA");
auto _vb = make_ref<VarNode>("transB");
auto _vc = make_ref<VarNode>("swapAB");
auto fakeSub = makeSubscript(matmul->getExpr(), {_va, _vb});
auto fakeRangeWrapperForHackHash =
makeRangeOperator({{_va, {0, Atrans + 100}},
{_vb, {0, Btrans + 100}},
{_vc, {0, ABswap + 100}}},
{}, fakeSub);
Matmul newMatmul =
make_ref<MatmulNode>(fakeRangeWrapperForHackHash, inputs[0],
inputs[1], b, m, n, k, transa, transb);
auto newTensor = makeTensor(derivator.newTensorName(), newShape,
newPaddings, newMatmul);
// build output transpose
if (ABswap) {
vector<Var> vars{derivator.getNewVar(), derivator.getNewVar()};
auto sub = makeSubscript(newTensor, {vars[1], vars[0]});
vector<VarRangePair> loopVRs;
// Sicne inputs array may be swaped, use the orignal tensor shape
for (int i = 0; i < 2; ++i) {
loopVRs.emplace_back(vars[i], Range(0, tensor->getShape(i)));
}
auto rangeOp = makeRangeOperator(loopVRs, {}, sub);
ret.emplace_back(rangeOp);
} else
ret.emplace_back(newTensor);
}
return ret;
}
Tensor MatmulTransposeMutator::transposeInput(const Tensor &tensor) {
Tensor ret;
if (auto ew = as<ElementWiseNode>(tensor->getSource())) {
auto rangeOp = as<RangeOpNode>(tensor->getSource()->getExpr());
assert(rangeOp);
assert(rangeOp->getNumOutputDims() == 2);
auto loopVRs = rangeOp->getLoopVarRanges();
std::swap(loopVRs[0], loopVRs[1]);
// If there are paddings, the inner stage paddings should be removed
assert(!rangeOp->hasPaddings());
// auto paddings = rangeOp->getPaddings();
// std::swap(paddings[0], paddings[1]);
auto sub = makeSubscript(rangeOp, {loopVRs[1].first, loopVRs[0].first});
auto newRangeOp = makeRangeOperator(loopVRs, {}, sub);
// ElementWise newElementWise = make_ref<ElementWiseNode>(*ew);
auto outputShape = ew->getOutputShape();
std::swap(outputShape[0], outputShape[1]);
auto newElementWise =
make_ref<ElementWiseNode>(newRangeOp, ew->getInputs(), outputShape);
auto tensorShape = tensor->getShape();
auto tensorPaddings = tensor->getPaddings();
std::swap(tensorShape[0], tensorShape[1]);
std::swap(tensorPaddings[0], tensorPaddings[1]);
ret = makeTensor(derivator.newTensorName(), tensorShape, tensorPaddings,
newElementWise);
// } else if (!tensor->getSource()) {
} else
nnet_unimplemented_halt();
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,135 @@
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/ReplaceNodeMutator.h"
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
// FIXME: fix empty expression in membound
assert(kernels.size() >= 1);
if (checkEmpty()) {
if (allowEmptyMembound)
return nullptr;
else
nnet_assert(false, "Empty membound expression");
}
// Nesting stages
auto expr = dispatch(kernels.back());
// Fusing stages
bool merged = false;
do {
merged = false;
RangeOp curRangeOp;
for (Expr *curExpr = &expr;
curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
auto curRangeOp = as<RangeOpNode>(*curExpr);
assert(CheckOOBVisitor().checkRangeOp(curRangeOp) == false);
auto summand = curRangeOp->getSummand();
if (auto subscriptOp = as<SubscriptNode>(summand)) {
if (auto mergedExpr = rule4StageMerging(*curExpr, true)) {
// dbg(*curExpr, mergedExpr);
*curExpr = mergedExpr;
merged = true;
break;
}
curExpr = subscriptOp->getObjectPtr();
nnet_assert(*curExpr != nullptr, __LINE__);
} else if (auto funcOp = as<FuncNode>(summand)) {
// Relu({...}[i,j])
curExpr = funcOp->getObject()->getObjectPtr();
} else
nnet_unimplemented_halt();
}
} while (merged);
return expr;
}
bool MergeMemboundMutator::checkEmpty() {
for (const auto &k : kernels) {
if (k == nullptr)
return true;
}
return false;
}
Expr MergeMemboundMutator::visit_(const Tensor &c) {
if (curDepth > 0)
return dispatch(kernels[--curDepth]);
else {
// Reach the last tensor, return it to reconstruct the total tree
return c;
}
}
Expr MergeMemboundMutator::rule4StageMerging(Expr &rCur,
bool mergeStageWithCalc) {
auto rangeOp0 = as<RangeOpNode>(rCur);
const Subscript &sub0 = as<SubscriptNode>(rangeOp0->getSummand());
if (!sub0)
return nullptr;
const auto &rangeOp1 = as<RangeOpNode>(sub0->getObject());
if (!rangeOp1)
return nullptr;
const auto &sub1 = as<SubscriptNode>(rangeOp1->getSummand());
if (!sub1)
return nullptr;
// merge stage with calculation only when mergeStageWithCalc=true
if (!mergeStageWithCalc && !rangeOp1->getSumVarRanges().empty())
return nullptr;
// Only propogate paddings in perfect nested dimension
if (rangeOp1->hasPaddings()) {
auto oldTensor = as<TensorNode>(sub1->getObject());
if (!oldTensor) {
nnet_unimplemented_continue();
return nullptr;
}
}
// repalce variables: iters of rangeOp1 repalced by indexes of sub0
map<string, pair<Expr, Expr>> varMapping;
assert(sub0->getDims() == rangeOp1->getLoopVarRanges().size());
for (size_t i = 0; i < sub0->getDims(); ++i) {
varMapping[rangeOp1->getLoopVar(i)->getName()] =
pair(rangeOp1->getLoopVar(i), sub0->getIndex(i));
}
ReplaceVariable replaceVariable{varMapping};
auto merged = make_ref<RangeOpNode>(*rangeOp0);
merged->setSummand(replaceVariable(sub1));
// a naive approach to propogate paddings
if (rangeOp1->hasPaddings()) {
auto oldTensor = as<TensorNode>(sub1->getObject());
auto newTensor = make_ref<TensorNode>(*oldTensor);
for (int i = 0; i < rangeOp1->getNumOutputDims(); ++i) {
if (rangeOp1->getPaddings(i) == 0)
continue;
auto loopVar = rangeOp1->getLoopVar(i);
// FIXME: in fact this var should not appear in other index as well,
// which may result in OOB
bool findSingleVarAsIndex = false;
for (size_t subIndexID = 0; subIndexID < sub1->getDims();
++subIndexID) {
auto index = sub1->getIndex(subIndexID);
if (auto indexVar = as<VarNode>(index);
indexVar && (indexVar->equal(loopVar))) {
newTensor->setPadding(subIndexID,
newTensor->getPadding(subIndexID) +
rangeOp1->getPaddings(i));
findSingleVarAsIndex = true;
}
}
if (!findSingleVarAsIndex) {
nnet_unimplemented_continue();
return nullptr;
}
}
merged = as<RangeOpNode>(
ReplaceNodeMutator().replace(merged, oldTensor.get(), newTensor));
assert(merged != nullptr);
}
// Merge inner stage sums
if (!rangeOp1->getSumVarRanges().empty())
merged->setSumIterator(rangeOp1->getSumVarRanges());
return merged;
}
} // namespace nnet

View File

@ -0,0 +1,124 @@
#include "nnet/Visitor/PatternMatcher.h"
#include "nnet/ReplaceKit.h"
#include "nnet/Visitor/MatmulTransposeMutator.h"
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
PatternMatcher::PatternMatcher(Derivator &derivator, const RangeOp &cur)
: Functor(false), derivator(derivator), hasNonZeroRange(false),
originalCur(cur) {
for (const auto &[var, range] : cur->getLoopVarRanges())
if (range.first != 0)
hasNonZeroRange = true;
}
RangeOp PatternMatcher::getOffsetCur() {
if (!hasNonZeroRange)
return originalCur;
vector<Var> itersFromNonZero;
VecExpr psis;
vector<VarRangePair> newLoopVarRanges;
for (const auto &[var, range] : originalCur->getLoopVarRanges()) {
if (range.first == 0) {
newLoopVarRanges.emplace_back(var, range);
} else {
auto newVar = derivator.getNewVar();
newLoopVarRanges.emplace_back(newVar,
pair(0, range.second - range.first));
itersFromNonZero.emplace_back(var);
psis.emplace_back(newVar + range.first);
}
}
auto newSummand = ReplaceKit::replaceMultipleExprs(
originalCur->getSummand(), itersFromNonZero, psis);
return makeRangeOperator(newLoopVarRanges, originalCur->getSumVarRanges(),
newSummand);
}
VecExpr PatternMatcher::matchKernel(const Pattern &pattern,
const RangeOp &rangeOp,
IteratorTable &exprIT) {
VecExpr ret;
if (pattern.getNumTensors() != (int)exprIT.getNumTensors())
return ret;
// Whether enable tensor permutation
if (false) {
const int nInputs = pattern.getNumInputs();
vector<int> tensorMap; // [tensors Index] -> pattern tensor ID
for (int i = 0; i < nInputs; ++i)
tensorMap.emplace_back(i);
do {
exprIT.buildTable(tensorMap);
auto matched = matchKernelWithTensorMap(pattern, rangeOp, exprIT);
if (matched)
ret.emplace_back(matched);
} while (std::next_permutation(tensorMap.begin(), tensorMap.end()));
} else {
exprIT.buildTableWithDefaultMap();
auto matched = matchKernelWithTensorMap(pattern, rangeOp, exprIT);
if (matched)
ret.emplace_back(matched);
}
// Generate 8 variants of gemm
if (true) // Disabled for debug
if (!ret.empty() && dynamic_cast<const MatmulPattern *>(&pattern)) {
auto tensor = as<TensorNode>(ret[0]);
auto transposeds =
MatmulTransposeMutator(derivator).transpose(tensor);
for (const auto &transposed : transposeds)
ret.emplace_back(transposed);
}
return ret;
}
Expr PatternMatcher::matchKernelWithTensorMap(const Pattern &pattern,
const RangeOp &rangeOp,
IteratorTable &exprIT) {
auto mismatches = exprIT.matchPatternIT(pattern);
if (!mismatches.empty())
return nullptr;
const auto &[tensorMap_r, iterToRange_r] = exprIT.getReverseMap();
// // TODO: check OOB error
// for (int tensorID = 0; tensorID < pattern.getNumInputs(); ++tensorID) {
// if (!checkIndexOutOfBound(pattern.getIterInTensorDim(tensorID),
// tensorMap_r[tensorID], iterToRange_r))
// return nullptr;
// }
// matched! build expr for ret;
return pattern.buildExpr(rangeOp, tensorMap_r, iterToRange_r,
derivator.newTensorName(), exprIT);
}
VecExpr PatternMatcher::applyWrapper(const VecExpr &exprs) {
if (!hasNonZeroRange)
return exprs;
VecExpr ret, indexes;
for (const auto &[var, range] : originalCur->getLoopVarRanges()) {
if (range.first == 0) {
indexes.emplace_back(var);
} else {
hasNonZeroRange = true;
indexes.emplace_back(var - range.first);
}
}
for (auto &expr : exprs) {
auto newSub = makeSubscript(expr, indexes);
ret.emplace_back(makeRangeOperator(originalCur->getLoopVarRanges(), {},
newSub, originalCur->getPaddings()));
}
return ret;
}
VecExpr PatternMatcher::matchWithPattern(const RangeOp &rangeOp,
const Pattern &pattern) {
IteratorTable exprIT;
if (!exprIT.analyzeExpr(rangeOp))
return {};
return matchKernel(pattern, rangeOp, exprIT);
}
} // namespace nnet

View File

@ -0,0 +1,56 @@
#include "nnet/Visitor/RangeMagnifyVisitor.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
Expr RangeMagnifyVisitor::visit_(const RangeOp &c) {
if (auto expr = Mutator::visit_(c)) {
auto ret = as<RangeOpNode>(expr);
ret->setSumIterator(newSumVarRanges);
return ret;
} else
return nullptr;
}
Expr RangeMagnifyVisitor::visit_(const Subscript &c) {
auto tensor = as<TensorNode>(c->getObject());
if (!tensor)
return nullptr;
// Check new ranges
bool paddingMagnify = false;
vector<Range> tensorRanges = c->getObjectRangesWithPaddings();
vector<int> paddingsDelta(tensorRanges.size(), 0);
for (int i = 0; i < (int)c->getDims(); ++i) {
auto indexRange =
SimplifyExprVisitor().getExprRange(c->getIndex(i), newRangeOp);
if (!indexRange.has_value())
return nullptr;
int delta = max(tensorRanges[i].first - indexRange->first,
indexRange->second - tensorRanges[i].second);
if (delta > 0) {
paddingMagnify = true;
paddingsDelta[i] = delta;
}
}
if (!paddingMagnify)
return nullptr;
// Create new tensor. Direct add paddings to the Tensor.
auto newTensor = make_ref<TensorNode>(*tensor);
for (int i = 0; i < newTensor->getDims(); ++i)
newTensor->setPadding(i, newTensor->getPadding(i) + paddingsDelta[i]);
auto newSub = make_ref<SubscriptNode>(*c);
newSub->setObject(newTensor);
return newSub;
}
RangeOp
RangeMagnifyVisitor::magnify(const RangeOp &root,
const vector<VarRangePair> &_newSumVarRanges) {
newSumVarRanges = _newSumVarRanges;
newRangeOp = make_ref<RangeOpNode>(*root);
newRangeOp->setSumIterator(newSumVarRanges);
const auto &newCur = as<RangeOpNode>(dispatch(root));
return newCur;
}
} // namespace nnet

View File

@ -0,0 +1,133 @@
#include "nnet/Visitor/RangeRelaxFunctor.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
RangeMap RangeRelaxFunctor::visit_(const BinaryOp &c) {
if (verbose)
dbg(*c);
if (c->getOpType() == OpType::Mul)
return intersectRangeMaps(dispatch(c->getLhs()), dispatch(c->getRhs()));
nnet_unimplemented_halt();
return RangeMap();
}
RangeMap RangeRelaxFunctor::visit_(const RangeOp &c) {
if (verbose)
dbg(*c);
return dispatch(c->getSummand());
}
RangeMap RangeRelaxFunctor::visit_(const Subscript &c) {
if (verbose)
dbg(*c);
RangeMap ret;
const auto &tensor = as<TensorNode>(c->getObject());
for (size_t i = 0; i < c->getDims(); ++i) {
const int padding = tensor->getPadding(i);
const int length = tensor->getShape(i);
if (!padding)
continue;
// TODO assert access within padding+length
// If the index is a single Var
if (auto var = as<VarNode>(c->getIndex(i))) {
ret = intersectRangeMaps(ret, {{var, {0, length}}});
} else { // If the index is linear expression
const auto &[strides, offset] =
SimplifyExprVisitor().getStridesConstant(c->getIndex(i));
// // Calculate the sum of ranges of all iters in negative value
// Range allRange(-offset, -offset);
// for (const auto &[iter, stride] : strides) {
// auto iterRange = rangeOp->getVarRange(iter).second;
// if (stride > 0) {
// allRange.first -= stride * (iterRange.second - 1);
// allRange.second -= stride * iterRange.first;
// } else {
// allRange.first += stride * iterRange.first;
// allRange.second += stride * (iterRange.second - 1);
// }
// dbg(iter, stride, iterRange, allRange);
// }
// dbg(allRange);
// // Calculate the meaningful ranges for each iter
// for (const auto &[iter, stride] : strides) {
// auto iterRange = rangeOp->getVarRange(iter).second;
// auto rangeExceptThis{allRange};
// if (stride > 0) {
// rangeExceptThis.first += stride * (iterRange.second - 1);
// rangeExceptThis.second += stride * iterRange.first;
// } else {
// rangeExceptThis.first -= stride * iterRange.first;
// rangeExceptThis.second -= stride * (iterRange.second -
// 1);
// }
// // Meaningful calculation range for current iter
// int l, r;
// if (stride > 0) {
// // l = (0 - rangeExceptThis.second + stride - 1) /
// stride;
// // r = (length - rangeExceptThis.first) / stride;
// l = (0 - rangeExceptThis.second + stride - 1) / stride;
// r = (length - 1 - rangeExceptThis.first) / stride + 1;
// } else {
// nnet_unimplemented_continue();
// continue;
// }
// dbg(iter, stride, iterRange, l, r);
// ret = intersectRangeMaps(ret, {{iter, {l, r}}});
// }
// Calculate the sum of ranges of all iters in negative value
Range allRange(offset, offset);
for (const auto &[iter, stride] : strides) {
auto iterRange = rangeOp->getVarRange(iter).second;
if (stride > 0) {
allRange.first += stride * iterRange.first;
allRange.second += stride * (iterRange.second - 1);
} else {
allRange.first += stride * (iterRange.second - 1);
allRange.second += stride * iterRange.first;
}
// dbg(iter, stride, iterRange, allRange);
}
// Calculate the meaningful ranges for each iter
for (const auto &[iter, stride] : strides) {
auto iterRange = rangeOp->getVarRange(iter).second;
auto rangeExceptThis{allRange};
if (stride > 0) {
rangeExceptThis.first -= stride * iterRange.first;
rangeExceptThis.second -= stride * (iterRange.second - 1);
} else {
rangeExceptThis.first -= stride * (iterRange.second - 1);
rangeExceptThis.second -= stride * iterRange.first;
}
// Meaningful calculation range for current iter
int l, r;
if (stride > 0) {
// l = (0 - rangeExceptThis.second + stride - 1) / stride;
// r = (length - rangeExceptThis.first) / stride;
l = (0 - rangeExceptThis.second + stride - 1) / stride;
r = (length - 1 - rangeExceptThis.first) / stride + 1;
} else {
nnet_unimplemented_continue();
continue;
}
ret = intersectRangeMaps(ret, {{iter, {l, r}}});
}
}
}
return ret;
}
RangeMap RangeRelaxFunctor::intersectRangeMaps(const RangeMap &a,
const RangeMap &b) {
RangeMap ret(a);
for (const auto &[k, v] : b) {
if (!ret.count(k))
ret[k] = v;
else {
auto const &u = ret[k];
ret[k] = {max(u.first, v.first), min(u.second, v.second)};
}
}
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,23 @@
#include "nnet/Visitor/ReplaceNodeMutator.h"
namespace nnet {
Expr ReplaceNodeMutator::visit_(const Subscript &c) {
if (c.get() == target)
return replacement;
return Mutator::visit_(c);
}
Expr ReplaceNodeMutator::visit_(const Tensor &c) {
if (c.get() == target)
return replacement;
return nullptr;
}
Expr ReplaceNodeMutator::replace(const Expr &root, ExprNode *_target,
const Expr &_replace) {
target = _target;
replacement = _replace;
return dispatch(root);
}
} // namespace nnet

View File

@ -0,0 +1,50 @@
#include "nnet/Visitor/ReplaceVariable.h"
namespace nnet {
Expr ReplaceVariable::visit_(const BinaryOp &c) {
if (verbose)
dbg(*c);
if (auto mutate = match(c); mutate)
return mutate;
else
return Mutator::visit_(c);
}
Expr ReplaceVariable::visit_(const Var &c) {
if (verbose)
dbg(*c);
if (auto mutate = match(c); mutate)
return mutate;
else
return Mutator::visit_(c);
}
Expr ReplaceVariable::visit_(const RangeOp &c) {
if (verbose)
dbg(*c);
return nullptr;
}
void ReplaceVariable::set(VecExpr _pattern, VecExpr _replacement) {
patterns = _pattern;
replacements = _replacement;
for (size_t i = 0; i < patterns.size(); ++i) {
auto hash = patterns[i]->hash();
assert(patternHash.count(hash) == 0);
patternHash[hash] = i;
}
}
Expr ReplaceVariable::match(const Expr &c) {
auto hash = c->hash();
if (auto it = patternHash.find(hash); it != patternHash.end()) {
const auto &i = it->second;
if (verbose)
dbg("Match", *c, *patterns[i], c->hash());
return replacements[i];
}
return nullptr;
}
} // namespace nnet

View File

@ -0,0 +1,245 @@
#include "nnet/Visitor/Serializer.h"
#include "nlohmann/json.hpp"
#include "nnet/expr.h"
#include <fstream>
namespace nnet {
int Serializer::id = 0;
Serializer::Serializer(int _verobse)
: Functor(_verobse), jPtr(std::make_unique<json>()), j(*jPtr) {}
Serializer::~Serializer() = default;
string Serializer::visit_(const Constant &c) {
string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["val"] = c->getValue();
return key;
}
string Serializer::visit_(const BinaryOp &c) {
string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["opType"] = (int)c->getOpType();
j[key]["lhs"] = dispatch(c->getLhs());
j[key]["rhs"] = dispatch(c->getRhs());
return key;
}
string Serializer::visit_(const RangeOp &c) {
string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["paddings"] = c->getPaddings();
j[key]["summand"] = dispatch(c->getSummand());
for (auto var : c->getLoopVarRanges()) {
j[key]["loopVarRanges"][var.first->getName()] = var.second;
}
for (auto var : c->getSumVarRanges()) {
j[key]["sumVarRanges"][var.first->getName()] = var.second;
}
return key;
}
string Serializer::visit_(const Subscript &c) {
string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["subExprsNum"] = c->getDims();
j[key]["object"] = dispatch(c->getObject());
vector<string> indexes;
for (auto index : c->getIndex()) {
indexes.emplace_back(dispatch(index));
}
j[key]["indexes"] = indexes;
return key;
}
string Serializer::visit_(const Var &c) {
string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["name"] = c->getName();
return key;
}
string Serializer::visit_(const Tensor &c) {
const string key = std::to_string(id++);
j[key]["type"] = c->getType();
j[key]["name"] = c->getName();
j[key]["shape"] = c->getShape();
j[key]["paddings"] = c->getPaddings();
const auto &routine = c->getSource();
j[key]["source"] = dispatchRoutine(routine);
return key;
}
bool Serializer::serialize(const Expr &expr, const string &filePath,
const string &msg) {
// Metadata
j["Version"] = VERSION;
j["Msg"] = msg;
// Expressions and routines
id = 0;
dispatch(expr);
std::ofstream fout(filePath);
fout << std::setw(4) << j << std::endl;
return true;
}
string Serializer::dispatchRoutine(const Routine &c) {
if (!c)
return "-1";
const string key = std::to_string(id++);
j[key]["type"] = c->getType();
vector<string> inputs;
for (const auto &tensor : c->getInputs())
inputs.emplace_back(dispatch(tensor));
j[key]["inputs"] = inputs;
if (const auto &expr = c->getExpr())
j[key]["expr"] = dispatch(expr);
else
j[key]["expr"] = "-1";
switch (c->getType()) {
case RoutineType::NoneType:
nnet_unimplemented_halt();
break;
case RoutineType::MatmulNodeType: {
j[key]["args"] = as<MatmulNode>(c)->getArgs();
break;
}
case RoutineType::ConvNodeType:
j[key]["args"] = as<ConvNode>(c)->getArgs();
break;
case RoutineType::G2bmmNodeType:
j[key]["args"] = as<G2bmmNode>(c)->getArgs();
break;
case RoutineType::GbmmNodeType:
j[key]["args"] = as<GbmmNode>(c)->getArgs();
break;
case RoutineType::ElementWiseNodeType: {
j[key]["outputShape"] = as<ElementWiseNode>(c)->getOutputShape();
break;
}
default:
nnet_unimplemented_halt();
}
return key;
}
Expr Serializer::deserialize(const string &filePath) {
std::ifstream fin(filePath);
fin >> j;
assert(j["Version"] == VERSION);
return buildExprTree("0");
}
Expr Serializer::buildExprTree(string key) {
switch (NodeType(j[key]["type"])) {
case NodeType::ConstantNodeType: {
return make_ref<ConstantNode>(j[key]["val"]);
}
case NodeType::BinaryOpNodeType: {
auto lhs = buildExprTree(j[key]["lhs"]);
auto rhs = buildExprTree(j[key]["rhs"]);
return make_ref<BinaryOpNode>(j[key]["opType"], lhs, rhs);
}
case NodeType::RangeOpNodeType: {
vector<VarRangePair> loopIters, sumIters;
for (auto &loopIter : j[key]["loopVarRanges"].items()) {
loopIters.emplace_back(
pair(make_ref<VarNode>(loopIter.key()),
pair(loopIter.value()[0], loopIter.value()[1])));
}
for (auto &sumIter : j[key]["sumVarRanges"].items()) {
sumIters.emplace_back(
pair(make_ref<VarNode>(sumIter.key()),
pair(sumIter.value()[0], sumIter.value()[1])));
}
auto summand = buildExprTree(j[key]["summand"]);
auto paddings = j[key]["paddings"].get<std::vector<int>>();
auto rangeOp = makeRangeOperator(loopIters, sumIters, summand);
rangeOp->setPaddings(paddings);
return rangeOp;
}
case NodeType::SubscriptNodeType: {
auto indexed = buildExprTree(j[key]["object"]);
VecExpr subExprs;
for (int i = 0, iEnd = j[key]["subExprsNum"]; i < iEnd; i++) {
subExprs.emplace_back(buildExprTree(j[key]["indexes"][i]));
}
return make_ref<SubscriptNode>(indexed, subExprs);
}
case NodeType::VarNodeType: {
return make_ref<VarNode>(j[key]["name"]);
}
case NodeType::TensorNodeType: {
auto source = buildRoutine(j[key]["source"]);
return make_ref<TensorNode>(j[key]["name"], j[key]["shape"],
j[key]["paddings"], source);
}
default: {
nnet_unimplemented_halt();
break;
}
}
return nullptr;
}
Routine Serializer::buildRoutine(string key) {
if (key == "-1")
return nullptr;
Expr expr = nullptr;
if (j[key]["expr"] != "-1")
expr = buildExprTree(j[key]["expr"]);
vector<Tensor> inputs;
for (const auto &input : j[key]["inputs"])
inputs.emplace_back(as<TensorNode>(buildExprTree(input)));
switch (RoutineType(j[key]["type"])) {
case RoutineType::NoneType:
nnet_unimplemented_halt();
break;
case RoutineType::MatmulNodeType: {
assert(inputs.size() == 2);
auto args = j[key]["args"].get<MatmulArgs>();
auto ctorArgs =
std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
return make_ref_from_tuple<MatmulNode>(ctorArgs);
break;
}
case RoutineType::ConvNodeType: {
assert(inputs.size() == 2);
auto args = j[key]["args"].get<ConvArgs>();
auto ctorArgs =
std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
return make_ref_from_tuple<ConvNode>(ctorArgs);
break;
}
case RoutineType::G2bmmNodeType: {
auto args = j[key]["args"].get<G2bmmArgs>();
auto ctorArgs =
std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
return make_ref_from_tuple<G2bmmNode>(ctorArgs);
break;
}
case RoutineType::GbmmNodeType: {
auto args = j[key]["args"].get<GbmmArgs>();
auto ctorArgs =
std::tuple_cat(std::tie(expr, inputs[0], inputs[1]), args);
return make_ref_from_tuple<GbmmNode>(ctorArgs);
break;
}
case RoutineType::ElementWiseNodeType: {
return make_ref<ElementWiseNode>(expr, inputs, j[key]["outputShape"]);
break;
}
default:
nnet_unimplemented_halt();
}
return nullptr;
}
} // namespace nnet

View File

@ -0,0 +1,160 @@
#include "nnet/Visitor/SimplifyExprVisitor.h"
#include "nnet/Visitor/StrideVisitor.h"
namespace nnet {
pair<PtrMap<Iterator, int>, int>
SimplifyExprVisitor::getStridesConstant(const Expr &expr) {
constant = 0;
divStrides.clear();
modStrides.clear();
subexprStride = StrideVisitor().getExprStride(expr);
strides.clear();
dispatch(expr, 1);
return {strides, constant};
}
optional<Range> SimplifyExprVisitor::getExprRange(const Expr &expr,
const RangeOp &rangeOp) {
constant = 0;
divStrides.clear();
modStrides.clear();
subexprStride = StrideVisitor().getExprStride(expr);
strides.clear();
dispatch(expr, 1);
// Skip if there is divide and modulo
if (!divStrides.empty() || !modStrides.empty() || !divExprStrides.empty() ||
!modExprStrides.empty())
return {};
Range ret{constant, constant + 1};
for (const auto &[iter, stride] : strides) {
const auto &[l, r] = rangeOp->getRange(iter);
if (stride > 0) {
ret.first += l * stride;
ret.second += (r - 1) * stride;
} else {
ret.first += (r - 1) * stride;
ret.second += l * stride;
}
}
return ret;
}
PtrMap<Iterator, int> SimplifyExprVisitor::getStrides(const Expr &expr) {
return getStridesConstant(expr).first;
}
int SimplifyExprVisitor::getConstant(const Expr &expr) {
return getStridesConstant(expr).second;
}
Expr SimplifyExprVisitor::simplify(const Expr &expr) {
getStrides(expr);
Expr ret = nullptr;
// merge divide and modulo items
for (const auto &[iterDividerPair, divStride] : divStrides) {
const auto &[iter, mod] = iterDividerPair;
// mod < 0 is a marker for merging vars with negtive strides. In math,
// divider < 0 is not well-defined for mod, so it should be exist in our
// epxrs and is only a temporary state which must be simpilified now.
if (mod < 0) { // must perfectly merged.
const auto &modStride = modStrides[iterDividerPair];
assert(divStride / abs(mod) == modStride);
assert(divStride > 0);
strides.try_emplace(iterDividerPair.first, 0);
strides[iterDividerPair.first] += abs(divStride / mod);
modStrides.erase(iterDividerPair);
} else if (divStride % mod == 0 && modStrides.count(iterDividerPair)) {
const auto &modStride = modStrides[iterDividerPair];
if (divStride / mod == modStride) {
strides.try_emplace(iterDividerPair.first, 0);
strides[iterDividerPair.first] += divStride / mod;
modStrides.erase(iterDividerPair);
} else
ret = ret + divStride * (iterDividerPair.first /
iterDividerPair.second);
} else
ret = ret +
divStride * (iterDividerPair.first / iterDividerPair.second);
}
// remaining modulo items
for (const auto &[iterDividerPair, stride] : modStrides) {
ret = ret + stride * (iterDividerPair.first % iterDividerPair.second);
}
// normal constant*variable items
for (const auto &[iter, stride] : strides) {
if (stride == 0)
continue;
Expr subexpr;
if (stride == 1)
subexpr = iter;
else
subexpr = stride * iter;
ret = (ret) ? ret + subexpr : subexpr;
}
// not perfectly nested divide and modulo items
for (const auto &[iterDividerPair, stride] : divExprStrides) {
ret = ret + stride * (iterDividerPair.first / iterDividerPair.second);
}
for (const auto &[iterDividerPair, stride] : modExprStrides) {
ret = ret + stride * (iterDividerPair.first % iterDividerPair.second);
}
ret = ret + constant;
return ret ? ret : make_ref<ConstantNode>(0);
}
void SimplifyExprVisitor::visit_(const BinaryOp &c, optional<int> stride) {
if (verbose)
dbg(c);
if (c->getOpType() == OpType::Add) {
dispatch(c->getLhs(), stride);
dispatch(c->getRhs(), stride);
} else if (c->getOpType() == OpType::Sub) {
dispatch(c->getLhs(), stride);
if (stride)
*stride = -*stride;
dispatch(c->getRhs(), stride);
} else if (c->getOpType() == OpType::Mul) {
const optional<int> &lStride = subexprStride[c->getLhs().get()];
const optional<int> &rStride = subexprStride[c->getRhs().get()];
optional<int> lCurStride =
(stride && rStride) ? optional(*stride * *rStride) : nullopt;
optional<int> rCurStride =
(stride && lStride) ? optional(*stride * *lStride) : nullopt;
dispatch(c->getLhs(), lCurStride);
dispatch(c->getRhs(), rCurStride);
} else if (c->getOpType() == OpType::Mod) {
const auto &param = c->getModDivParameter();
if (param.has_value()) {
modStrides.try_emplace(*param, 0);
modStrides[*param] += stride.value();
} else {
const auto &paramExpr = c->getModDivExpr();
modExprStrides.try_emplace(paramExpr, 0);
modExprStrides[paramExpr] += stride.value();
}
} else if (c->getOpType() == OpType::Div) {
const auto &param = c->getModDivParameter();
if (param.has_value()) {
divStrides.try_emplace(*param, 0);
divStrides[*param] += stride.value();
} else {
const auto &paramExpr = c->getModDivExpr();
divExprStrides.try_emplace(paramExpr, 0);
divExprStrides[paramExpr] += stride.value();
}
} else
nnet_unimplemented_halt();
}
void SimplifyExprVisitor::visit_(const Var &c, optional<int> stride) {
if (verbose)
dbg(c);
strides.try_emplace(c);
strides[c] += stride.value();
}
void SimplifyExprVisitor::visit_(const Constant &c, optional<int> stride) {
if (stride.has_value())
constant += stride.value() * c->getValue();
}
} // namespace nnet

View File

@ -0,0 +1,30 @@
#include "nnet/Visitor/SimplifyFormulaMutator.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
Expr SimplifyFormulaMutator::visit_(const Subscript &c) {
++nSubscripts;
if (verbose)
dbg(*c);
bool modified = false;
auto ret = make_ref<SubscriptNode>(*c);
for (size_t i = 0; i < ret->getDims(); ++i) {
const auto &e = ret->getIndex(i);
if (const auto &mutated = SimplifyExprVisitor().simplify(e)) {
modified = true;
ret->setIndex(i, mutated);
}
}
return (modified) ? ret : nullptr;
}
Expr SimplifyFormulaMutator::simplify(const Expr &expr) {
nSubscripts = 0;
const auto &ret = dispatch(expr);
nnet_assert(nSubscripts > 0,
"Subscript NOT found. Use SimplifyFormulaMutator?");
return (ret) ? ret : expr;
}
} // namespace nnet

View File

@ -0,0 +1,59 @@
#include "nnet/Visitor/StrideVisitor.h"
namespace nnet {
optional<int> StrideVisitor::visit_(const Subscript &c) {
if (verbose)
dbg(*c);
auto object = as<TensorNode>(c->getObject());
assert(object);
for (int i = (int)c->getDims() - 1; i >= 0; --i)
this->dispatch(c->getIndex(i));
return {};
}
optional<int> StrideVisitor::visit_(const Var &c) {
if (verbose)
dbg(*c);
// assert(subexprStride.count(&c) == 0);
subexprStride[c.get()] = {};
return {};
}
optional<int> StrideVisitor::visit_(const Constant &c) {
if (verbose)
dbg(*c);
optional ret{c->getValue()};
// assert(subexprStride.count(&c) == 0);
subexprStride[c.get()] = ret;
return ret;
}
optional<int> StrideVisitor::visit_(const BinaryOp &c) {
if (verbose)
dbg(*c);
optional<int> strideL = this->dispatch(c->getLhs());
optional<int> strideR = this->dispatch(c->getRhs());
if (!strideL || !strideR)
return {};
optional<int> ret;
switch (c->getOpType()) {
case OpType::Add:
ret = optional(*strideL + *strideR);
break;
case OpType::Sub:
ret = optional(*strideL - *strideR);
break;
case OpType::Mul:
ret = optional(*strideL * *strideR);
break;
default:
nnet_unimplemented_halt();
break;
}
// assert(subexprStride.count(&c) == 0);
subexprStride[c.get()] = ret;
return ret;
}
} // namespace nnet

View File

@ -0,0 +1,156 @@
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/FullPrinterVisitor.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
namespace nnet {
constexpr int varPrefix = 11027;
constexpr int binPrefix = 11047;
constexpr int ssPrefix = 11057;
constexpr int addPrefix = 11059;
constexpr int mulPrefix = 11069;
constexpr int vecPrefix = 11071;
constexpr int tensorPrefix = 11083;
constexpr int valSuffix = 6214587;
static inline HashType hash(const HashType a, const HashType b) {
return (a * 10007 + b + 12345) % 1000000007;
}
static inline HashType hash(const std::string &s) {
HashType ret = 0;
for (auto c : s)
ret = hash(ret, c);
return ret;
}
static inline HashType hash(const OpType c) { return HashType(c); }
HashType HashVisitor::getHash(const Expr &c) { return dispatch(c); }
HashType HashVisitor::visit_(const Constant &c) {
auto val = c->getValue();
return genhash(val, valSuffix);
}
HashType HashVisitor::visit_(const BinaryOp &c) {
HashType hasha = dispatch(c->getLhs());
HashType hashb = dispatch(c->getRhs());
if (c->isSwapable()) {
if (hasha > hashb) {
std::swap(hasha, hashb);
}
}
return hash(binPrefix, hash(hash(c->getOpType()), hash(hasha, hashb)));
return 0;
}
HashType hashLoopVar(const int id, const Range &range) {
return hash(varPrefix, hash(id, hash(range.first, range.second)));
}
HashType hashSumVar(const Range &range) {
return hash(varPrefix, hash(range.first, range.second));
}
HashType HashVisitor::visit_(const RangeOp &c) {
// Identify loop variables
for (const auto &[var, range] : c->getLoopVarRanges()) {
nnet_assert(varHash.find(var) == varHash.end(),
"In HashVisiter::RangeOp invalid loop var.");
varHash[var] = hashLoopVar(nLoopVars++, range);
}
// Identify sum variables according to range
for (const auto &[var, range] : c->getSumVarRanges()) {
nnet_assert(varHash.find(var) == varHash.end(),
"In HashVisiter::RangeOp invalid sum var.");
varHash[var] = hashSumVar(range);
}
auto expr = c->getSummand();
return dispatch(expr);
}
HashType HashVisitor::visit_(const Subscript &c) {
HashType curHash = ssPrefix;
auto obj = c->getObject();
if (obj->getType() == NodeType::RangeOpNodeType) {
curHash = hash(curHash, HashVisitor().getHash(obj));
} else if (obj->getType() == NodeType::TensorNodeType) {
// TODO: hash should based on arguments
curHash = hash(curHash, dispatch(obj));
} else {
nnet_unimplemented_halt();
}
for (const auto &expr : c->getIndex()) {
if (expr->getType() == NodeType::BinaryOpNodeType) {
HashType tmp = addPrefix;
std::vector<std::pair<HashType, HashType>> coefficients;
auto seVisitor = SimplifyExprVisitor();
auto [c, x] = seVisitor.getStridesConstant(expr);
for (const auto &[key, value] : c) {
coefficients.emplace_back(varHash[key], value);
}
for (const auto &[iter, value] : seVisitor.getDivStrides()) {
nnet_assert(iter.second != 1, "invalid div expr");
coefficients.emplace_back(
hash(binPrefix, hash(varHash[iter.first], iter.second)),
value);
}
sort(coefficients.begin(), coefficients.end());
tmp = hash(tmp, x);
for (const auto &[key, value] : coefficients) {
tmp = hash(tmp, hash(mulPrefix, hash(key, value)));
}
curHash = hash(curHash, tmp);
continue;
}
if (expr->getType() == NodeType::ConstantNodeType) {
curHash = hash(curHash, dispatch(expr));
continue;
}
if (expr->getType() == NodeType::VarNodeType) {
curHash = hash(curHash, dispatch(expr));
continue;
}
nnet_unimplemented_halt();
}
return curHash;
}
HashType hashPadding(const std::vector<int> &pad) {
HashType cur = hash(vecPrefix, pad.size());
for (const auto &e : pad) {
cur = hash(cur, e);
}
return cur;
}
HashType HashVisitor::visit_(const Tensor &c) {
// TODO: remove this
// TODO: check if hash name includes padding.
if (c->getSource() == nullptr) {
return hash(tensorPrefix, genhash(c->getName()));
}
// dbg(c, c->getSource()->getExpr(), FullPrinterVisitor().print(c),
// FullPrinterVisitor().print(c->getSource()->getExpr()));
// std::cout << "Tensor: " << int(c->getSource()->getExpr()->getType())
// << std::endl;
// std::cout << "Tensor: " << c->getSource()->getExpr()->toReadable()
// << std::endl;
return hash(tensorPrefix,
HashVisitor().dispatch(c->getSource()->getExpr()));
}
HashType HashVisitor::visit_(const Var &c) {
if (varHash.find(c) == varHash.end()) {
nnet_unimplemented_halt();
return 0;
}
return varHash[c];
}
} // namespace nnet

529
src/nnet/derivator.cc Normal file
View File

@ -0,0 +1,529 @@
#include "nnet/derivator.h"
#include "nnet/Pass/MatchComputationKernel.h"
#include "nnet/Pass/MatchMemBoundKernel.h"
#include "nnet/Pass/Rule1VariableSplit.h"
#include "nnet/Pass/Rule2VariableMerging.h"
#include "nnet/Pass/Rule3StageSplit.h"
#include "nnet/Pass/Rule4StageMerging.h"
#include "nnet/Pass/Rule5RangeRelaxation.h"
#include "nnet/Pass/Rule6KenerlMatching.h"
#include "nnet/Pass/Rule7DLT.h"
#include "nnet/Pass/Rule8GuidedDLT.h"
#include "nnet/Pass/Rule90TwoStageElementWise.h"
#include "nnet/Pass/Rule91MergeStagesWithSum.h"
#include "nnet/Pass/Rule9RangeMagnify.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/CloneMutator.h"
#include "nnet/Visitor/CompareMultiFormulasVisitor.h"
#include "nnet/Visitor/CountRoutineVisitor.h"
#include "nnet/Visitor/FullPrinterVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/Visitor/Serializer.h"
#include "nnet/test.h"
namespace nnet {
class SaveStateGuard {
Derivator &derivator;
public:
SaveStateGuard(Derivator &derivator, const Expr &origin,
const string &ruleName, const string &ruleMsg = "")
: derivator(derivator) {
derivator.pushIntermediateState(origin);
derivator.pushRuleState(ruleName);
derivator.pushRuleMsg(ruleMsg);
}
~SaveStateGuard() {
derivator.popIntermediateState();
derivator.popRuleState();
derivator.popRuleMsg();
}
};
#define SetUpStateGuard() \
SaveStateGuard __guard(*this, origin.root, __FUNCTION__)
void Derivator::dfs(Formula &origin, int depth) {
guidedSearch(origin, depth);
if (depth >= maxDepth) {
return;
}
Expr *curExpr = &origin.root;
nnet_assert((*curExpr)->getType() == NodeType::RangeOpNodeType, __LINE__);
while ((*curExpr)->getType() == NodeType::RangeOpNodeType) {
auto curRangeOp = as<RangeOpNode>(*curExpr);
checkOOB(curRangeOp);
auto summand = curRangeOp->getSummand();
if (summand->getType() == NodeType::SubscriptNodeType) {
auto subscriptOp = as<SubscriptNode>(summand);
if (rule4StageMerging(origin, depth, *curExpr)) {
return;
}
curExpr = subscriptOp->getObjectPtr();
nnet_assert(*curExpr != nullptr, __LINE__);
continue;
}
if (summand->getType() == NodeType::BinaryOpNodeType) {
if (cntAppliedRules[1] < 3)
rule1VariableSplit(origin, depth, *curExpr); // +1/0
rule2VariableMerging(origin, depth, *curExpr); // +1
if (cntAppliedRules[3] < 1)
rule3StageSplit(origin, depth, *curExpr); // +1
rule5RangeRelaxation(origin, depth, *curExpr); // 0
rule7DLT(origin, depth, *curExpr);
rule9RangeMagnify(origin, depth, *curExpr);
return;
}
nnet_unimplemented_halt();
}
// RangeOp curRangeOp;
// for (Expr *curExpr = &origin.root;
// curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
// checkOOB(curRangeOp);
// auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
// // isSimplyNested: a directly nested stage
// bool isSimplyNested = (subscript &&
// subscript->isRangeOpSubscripted()); if (rule4StageMerging(origin,
// depth, *curExpr))
// return;
// // For the next nested stage
// curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
// }
// int stage = 0;
// for (Expr *curExpr = &origin.root;
// curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
// stage++;
// // isSimplyNested: a directly nested stage
// auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
// bool isSimplyNested = (subscript &&
// subscript->isRangeOpSubscripted());
// // TODO recover it
// // permuteRangeOps(origin, depth, *curExpr);
// // extractSubexpression(origin, depth, *curExpr);
// rule4StageMerging(origin, depth, *curExpr);
// if (!isSimplyNested) {
// std::cout << "num stage: " << depth << " " << stage << std::endl;
// if (depth < 5) {
// rule1VariableSplit(origin, depth, *curExpr); // +1/0
// rule3StageSplit(origin, depth, *curExpr); // +1
// rule2VariableMerging(origin, depth, *curExpr); // +1
// rule5RangeRelaxation(origin, depth, *curExpr); // 0
// rule9RangeMagnify(origin, depth, *curExpr);
// }
// if (depth >= 5) {
// rule1VariableSplit(origin, depth, *curExpr); // +1/0
// rule3StageSplit(origin, depth, *curExpr); // +1
// rule2VariableMerging(origin, depth, *curExpr); // +1
// rule5RangeRelaxation(origin, depth, *curExpr); // 0
// rule6KenerlMatching(origin, depth, *curExpr); // -1
// rule7DLT(origin, depth, *curExpr); // +1
// rule8GuidedDLT(origin, depth, *curExpr); //
// rule9RangeMagnify(origin, depth, *curExpr);
// }
// }
// // For the next nested stage
// curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
// }
}
Derivator::Derivator(int maxDepth, bool enableHashPruning, LogMode logMode,
PassMode passMode)
: maxDepth(maxDepth), logMode(logMode), passMode(passMode),
enableHashPruning(enableHashPruning), cntAppliedRules(12) {}
int Derivator::getNumIntermediateStates() { return cntStates; }
void Derivator::guidedSearch(Formula &origin, int depth) {
if (origin.root->getType() == NodeType::TensorNodeType) {
auto tensor = as<TensorNode>(origin.root);
appendCanddiate(tensor, depth);
return;
}
Expr *expr = &origin.root;
nnet_assert((*expr)->getType() == NodeType::RangeOpNodeType, __LINE__);
while ((*expr)->getType() == NodeType::RangeOpNodeType) {
auto rangeOp = as<RangeOpNode>(*expr);
checkOOB(rangeOp);
auto summand = rangeOp->getSummand();
if (summand->getType() == NodeType::SubscriptNodeType) {
auto subscriptOp = as<SubscriptNode>(summand);
if (rule4StageMerging(origin, depth, *expr)) {
return;
}
expr = subscriptOp->getObjectPtr();
nnet_assert(*expr != nullptr, __LINE__);
continue;
}
if (summand->getType() == NodeType::BinaryOpNodeType) {
break;
}
nnet_unimplemented_halt();
}
if (searchState == 0) {
searchState = 1;
rule8GuidedDLT(origin, depth, *expr);
searchState = 0;
return;
}
if (searchState == 1) {
rule8GuidedDLT(origin, depth, *expr);
return;
}
if (searchState == 2) {
matchComputationKernel(origin, depth, *expr);
return;
}
if (searchState == 3) {
// Pack the remaining computation as a MemBoundOp
matchMemBoundKernel(origin, depth, origin.root);
return;
}
nnet_unimplemented_halt();
return;
}
void Derivator::ruleBasedDerivate(Formula &origin, int depth) {
// string StartDfs = "ruleBasedDerivate dep=" + std::to_string(depth) +
// ", targetOp=" +
// std::to_string(routineTypeToId(targetOp));
// dbg(StartDfs, origin);
auto tensor = as<TensorNode>(origin.root);
if (tensor) {
appendCanddiate(tensor, depth);
return;
}
if (depth >= (int)rulesOverall.size())
return;
RangeOp curRangeOp;
for (Expr *curExpr = &origin.root;
curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
checkOOB(curRangeOp);
auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
// isSimplyNested: a directly nested stage
bool isSimplyNested = (subscript && subscript->isRangeOpSubscripted());
if (rule4StageMerging(origin, depth, *curExpr))
return;
// For the next nested stage
curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
}
int stageDepth = 0;
for (Expr *curExpr = &origin.root;
curExpr && (curRangeOp = as<RangeOpNode>(*curExpr));) {
// isSimplyNested: a directly nested stage
auto subscript = as<SubscriptNode>(curRangeOp->getSummand());
bool isSimplyNested = (subscript && subscript->isRangeOpSubscripted());
stageDepth++;
for (int rule : rulesOverall[depth]) {
if (rule == 1)
rule1VariableSplit(origin, depth, *curExpr);
else if (!isSimplyNested) {
if (rule == 2)
rule2VariableMerging(origin, depth, *curExpr);
else if (rule == 3)
rule3StageSplit(origin, depth, *curExpr);
else if (rule == 5)
rule5RangeRelaxation(origin, depth, *curExpr);
else if (rule == 6)
rule6KenerlMatching(origin, depth, *curExpr);
else if (rule == 7)
rule7DLT(origin, depth, *curExpr);
else if (rule == 8)
rule8GuidedDLT(origin, depth, *curExpr);
else if (rule == 9)
rule9RangeMagnify(origin, depth, *curExpr);
}
}
// For the next nested stage
curExpr = (isSimplyNested) ? subscript->getObjectPtr() : nullptr;
}
for (int rule : rulesOverall[depth])
if (rule == 90 && stageDepth == 2) // HACK: for (T)Conv2gemm
rule90TwoStageElementWise(origin, depth, origin.root);
else if (rule == 91 && stageDepth >= 2) // HACK: for TConv2gemm
rule91MergeStagesWithSum(origin, depth, origin.root);
}
void Derivator::nextStep(Formula &origin, int depth, Expr &rCur, Expr newCur) {
// Count the number of searched states
++cntStates;
rCur.swap(newCur);
HashType formulaHash = HashVisitor().getHash(origin.root);
if (enableHashPruning) {
if (searchState != 2) {
if (visited.find(formulaHash) != visited.end()) {
rCur.swap(newCur);
return;
}
visited.emplace(formulaHash);
}
}
if (searchState > 0) {
guidedSearch(origin, depth);
} else {
searchedMaxDepth = max(searchedMaxDepth, depth + 1);
if (searchStrategy == Strategy::DFS ||
(searchStrategy == Strategy::RuleAndDFS &&
depth + 1 >= (ssize_t)rulesOverall.size()))
dfs(origin, depth + 1);
else
ruleBasedDerivate(origin, depth + 1);
}
rCur.swap(newCur);
}
void Derivator::ruleBasedDFS(Formula &origin, int depth, vector<int> _rules,
map<int, vector<Iterator>> _substituteRules,
bool searchAfterRules) {
SaveStateGuard guard(*this, origin.root, string("Init: ") + __FUNCTION__);
searchStrategy = (searchAfterRules) ? Strategy::RuleAndDFS : Strategy::Rule;
rulesOverall.clear();
for (auto i : _rules)
rulesOverall.push_back({i});
substituteRules = _substituteRules;
ruleBasedDerivate(origin, depth);
}
void Derivator::search(Formula &origin, int depth) {
SaveStateGuard guard(*this, origin.root, string("Init: ") + __FUNCTION__);
searchStrategy = Strategy::DFS;
dfs(origin, depth);
}
void Derivator::print() {
std::cout << "[RESULT] Derivator::results: " << candidates.size()
<< std::endl;
std::cout << "==== DFS candidates (" << candidates.size()
<< ")====" << std::endl;
for (const auto &f : candidates) {
std::cout << f.toReadable() << std::endl;
// dbg(f.bfsDepth, f.toReadable());
}
std::cout << "==== DFS log end ====" << std::endl;
}
string Formula::toReadable() const { return FullPrinterVisitor().print(root); }
void Derivator::rule1VariableSplit(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[1];
Rule1VariableSplit(*this).run(origin, depth, rCur);
--cntAppliedRules[1];
}
void Derivator::rule2VariableMerging(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[2];
Rule2VariableMerging(*this).run(origin, depth, rCur);
--cntAppliedRules[2];
}
void Derivator::rule3StageSplit(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[3];
Rule3StageSplit(*this).run(origin, depth, rCur);
--cntAppliedRules[3];
}
bool Derivator::rule4StageMerging(Formula &origin, int depth, Expr &rCur,
bool mergeStageWithCalc) {
++cntAppliedRules[4];
Rule4StageMerging pass(*this);
pass.setMergeStageWithCalc(mergeStageWithCalc);
pass.run(origin, depth, rCur);
--cntAppliedRules[4];
return pass.isSuccessful();
}
void Derivator::rule5RangeRelaxation(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[5];
Rule5RangeRelaxation(*this).run(origin, depth, rCur);
--cntAppliedRules[5];
}
void Derivator::rule6KenerlMatching(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[6];
Rule6KenerlMatching(*this).run(origin, depth, rCur);
--cntAppliedRules[6];
}
void Derivator::rule7DLT(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[7];
Rule7DLT(*this).run(origin, depth, rCur);
--cntAppliedRules[7];
}
void Derivator::rule8GuidedDLT(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[8];
Rule8GuidedDLT(*this).run(origin, depth, rCur);
--cntAppliedRules[8];
}
void Derivator::rule9RangeMagnify(Formula &origin, int depth, Expr &rCur) {
++cntAppliedRules[9];
Rule9RangeMagnify(*this).run(origin, depth, rCur);
--cntAppliedRules[9];
}
void Derivator::rule90TwoStageElementWise(Formula &origin, int depth,
Expr &rCur) {
Rule90TwoStageElementWise(*this).run(origin, depth, rCur);
}
void Derivator::rule91MergeStagesWithSum(Formula &origin, int depth,
Expr &rCur) {
Rule91MergeStagesWithSum(*this).run(origin, depth, rCur);
}
void Derivator::matchComputationKernel(Formula &origin, int depth, Expr &rCur) {
MatchComputationKernel(*this).run(origin, depth, rCur);
}
void Derivator::matchMemBoundKernel(Formula &origin, int depth, Expr &rCur) {
MatchMemBoundKernel(*this).run(origin, depth, rCur);
}
bool Derivator::stageCombination(MultiFormulas &origin, int depth) {
return (CompareMultiFormulasVisitor().compare(origin.roots));
}
Expr Derivator::mergeMemboundStages(VecExpr stages) {
auto nested = MergeMemboundMutator(stages).merge();
return nested;
}
void Derivator::appendCanddiate(const Tensor &tensor, int depth) {
// if (!CountRoutineVisitor().match(tensor, 1, 0, 3))
// return;
candidates.emplace_back(tensor, depth);
// dbg("!!!!!!!!!!!!!!!Success!!!!!!!!!!!!!!!");
if (enableEquivalenceCheck)
checkDerivationEquivalence();
// printIntermediateStates();
}
bool Derivator::checkOOB(const RangeOp &rangeOp, bool halt) {
// Skip check in NoLog mode
if (logMode == LogMode::NoLog)
return false;
bool hasOOB = CheckOOBVisitor().checkRangeOp(rangeOp);
if (hasOOB) {
printIntermediateStates();
dbg(FullPrinterVisitor().print(rangeOp));
if (halt)
nnet_assert(0, "Out Of Bound in index!");
}
return hasOOB;
}
string Derivator::newTensorName() {
return "T" + std::to_string(++nTensorNames);
}
Var Derivator::getNewVar() {
return make_ref<VarNode>("i" + std::to_string(++nIteratorNames));
}
void Derivator::pushIntermediateState(const Expr &expr) {
intermediateStates.emplace_back(CloneMutator().clone(expr));
};
void Derivator::pushRuleState(const string &state) {
ruleStates.emplace_back(state);
}
void Derivator::pushRuleMsg(const string &state) {
ruleMsgs.emplace_back(state);
}
void Derivator::popIntermediateState() { intermediateStates.pop_back(); }
void Derivator::popRuleState() { ruleStates.pop_back(); }
void Derivator::popRuleMsg() { ruleMsgs.pop_back(); }
RoutineType Derivator::getTargetOp() { return targetOp; }
void Derivator::setTargetOp(RoutineType _targetOp) { targetOp = _targetOp; }
int Derivator::getSearchState() { return searchState; }
void Derivator::setSearchState(int _searchState) { searchState = _searchState; }
void Derivator::printStatistics() {
printf("==== Derivator statistics ====\n");
printf("Max Depth = %d\n", maxDepth);
printf("searchStrategy = ");
if (searchStrategy == Strategy::DFS)
printf("DFS\n");
else if (searchStrategy == Strategy::Rule)
printf("Rule\n");
else if (searchStrategy == Strategy::RuleAndDFS)
printf("RuleAndDFS\n");
printf("enableHashPruning = %s\n", enableHashPruning ? "true" : "false");
printf("Reached Max Depth during search = %d\n", searchedMaxDepth);
printf("#Candidates = %lu\n", candidates.size());
printf("#Intermediate states = %d\n", cntStates);
printf("#Hashed intermediate states = %lu\n", visited.size());
printf("#Iteratos = %d\n", nIteratorNames);
printf("#Tensors = %d\n", nTensorNames);
}
void Derivator::setDumpFirstSuccess(const string &_logFnPrefix) {
setEquivalenceCheck();
logMode = LogMode::DumpFristCandiate;
logFnPrefix = _logFnPrefix;
}
void Derivator::printIntermediateStates() {
// Skip in NoLog mode
if (logMode == LogMode::NoLog)
return;
assert(intermediateStates.size() == ruleStates.size());
assert(intermediateStates.size() == ruleMsgs.size());
for (size_t i = 0; i < intermediateStates.size(); ++i) {
string msg = "=== Depth " + std::to_string(i) + " " + ruleStates[i] +
": " + ruleMsgs[i];
std::cout << msg << endl;
std::cout << FullPrinterVisitor().print(intermediateStates[i]) << endl;
if (logMode == LogMode::DumpFristCandiate) {
Serializer serializer;
serializer.serialize(intermediateStates[i],
logFnPrefix + to_string(i) + ".expr", msg);
}
}
for (size_t i = 0; i < intermediateStates.size(); ++i) {
if (auto cur = as<RangeOpNode>(intermediateStates[i]))
if (CheckOOBVisitor().checkRangeOp(cur)) {
printf("OOB detected depth=%lu\n", i);
}
}
if (logMode == LogMode::DumpFristCandiate) {
puts("Serializaiton finished.");
exit(0);
}
}
void Derivator::checkDerivationEquivalence() {
if (!checkExprsEquvivalence(intermediateStates)) {
nnet_assert(0, "Inequivalent derivation");
exit(1);
}
}
void Derivator::setEquivalenceCheck() { enableEquivalenceCheck = true; }
Derivator::PassMode Derivator::getPassMode() { return passMode; }
Derivator::LogMode Derivator::getLogMode() { return logMode; }
} // namespace nnet

131
src/nnet/dlt.cc Normal file
View File

@ -0,0 +1,131 @@
#include "nnet/dlt.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
#include "nnet/visitor.h"
#include <algorithm>
namespace nnet {
optional<Expr> DLT::apply(const RangeOp &rangeOp, const Subscript &subscript,
string newTensorName) {
const auto &tensor = as<TensorNode>(subscript->getObject());
assert(tensor != nullptr);
vector<int> shape0(tensor->getShape()), shape1;
const auto &paddings = tensor->getPaddings();
VecExpr index0 = subscript->getIndex(), index1;
// compute new shapes and index
for (const auto &opPtr : ops) {
if (auto op = as<DLTSplit>(opPtr)) {
assert(op->dim < (int)shape0.size());
for (int i = 0; i < (int)shape0.size(); ++i)
if (i != op->dim) {
shape1.emplace_back(shape0[i]);
index1.emplace_back(index0[i]);
} else {
assert(shape0[i] % op->factor == 0);
shape1.emplace_back(shape0[i] / op->factor);
shape1.emplace_back(op->factor);
if (const auto &opt =
splitIndex(index0[i], op->factor, rangeOp);
opt.has_value()) {
index1.emplace_back(opt->first);
index1.emplace_back(opt->second);
} else
return {};
}
} else if (auto op = as<DLTMerge>(opPtr)) {
assert(op->dim0 < (int)shape0.size());
assert(op->dim1 < (int)shape0.size());
for (int i = 0; i < (int)shape0.size(); ++i)
if (i == op->dim0) {
shape1.emplace_back(shape0[op->dim0] * shape0[op->dim1]);
index1.emplace_back(index0[op->dim0] * shape0[op->dim1] +
index0[op->dim1]);
} else if (i != op->dim1) {
shape1.emplace_back(shape0[i]);
index1.emplace_back(index0[i]);
}
} else if (auto op = as<DLTReorder>(opPtr)) {
if (op->dims.size() != shape0.size()) {
// TODO: input Reorder should have the same order with tensor
nnet_unimplemented_continue();
return {};
}
assert(op->dims.size() == shape0.size());
for (size_t i = 0; i < shape0.size(); ++i) {
shape1.emplace_back(shape0[op->dims[i]]);
index1.emplace_back(index0[op->dims[i]]);
}
}
for (const auto &index : index1) {
// Maybe there are bugs...
// assert(index != nullptr);
if (index == nullptr) {
std::cout << "Warning empty" << std::endl;
return {};
}
}
shape0.swap(shape1);
shape1.clear();
index0.swap(index1);
index1.clear();
}
for (auto &index : index0) {
// Maybe there are bugs...
assert(index != nullptr);
if (index == nullptr)
return {};
index = SimplifyExprVisitor().simplify(index);
}
// HACK DLT with paddings: transfer original paddings to the new one
vector<int> dltedPaddings =
(paddings.size() == shape0.size()) ? paddings : vector<int>{};
// TODO [necessary] build DLT source expr. Is OP-based DLT is good too?
// HACK [important] fix this fake tensor.
auto elementRoutine = make_ref<ElementWiseNode>(
// FIXME: implement transpose
// makeTensor(newTensorName + "_DLT", {}), vector<Tensor>{tensor},
// shape0);
makeTensor("__DLT", {}), vector<Tensor>{tensor}, shape0);
auto dltedTensor =
makeTensor(newTensorName, shape0, dltedPaddings, elementRoutine);
auto dltedSubscript = makeSubscript(dltedTensor, index0);
return optional<Expr>(std::in_place, dltedSubscript);
}
optional<pair<Expr, Expr>> DLT::splitIndex(Expr expr, int factor,
RangeOp rangeOp) {
auto strides = SimplifyExprVisitor().getStrides(expr);
Expr quotient, remainder;
for (const auto &[iter, stride] : strides) {
const auto &[var, range] = rangeOp->getVarRange(iter);
// Add new expr, dealing with the initial empty expr
auto addExpr = [](Expr &orig, const Expr &newExpr) {
if (!orig)
orig = newExpr;
else
orig = orig + newExpr;
};
if (abs(stride) >= factor) {
if (stride % factor)
return {};
addExpr(quotient, (stride / factor) * var);
} else {
if (stride * (range.second - range.first) > factor)
return {};
addExpr(remainder, stride * var);
}
}
return optional<pair<Expr, Expr>>(std::in_place, quotient, remainder);
}
void DLT::split(int dim, int factor) {
ops.emplace_back(make_ref<DLTSplit>(dim, factor));
}
void DLT::merge(int dim0, int dim1) {
ops.emplace_back(make_ref<DLTMerge>(dim0, dim1));
}
void DLT::reorder(vector<int> dims) {
ops.emplace_back(make_ref<DLTReorder>(dims));
}
} // namespace nnet

452
src/nnet/expr.cc Normal file
View File

@ -0,0 +1,452 @@
#include "nnet/expr.h"
#include "nnet/Visitor/GetTensorsVisitor.h"
namespace nnet {
string serializeVec(vector<Expr> v) {
if (v.empty())
return "[]";
return "[" +
std::accumulate(v.begin() + 1, v.end(), v[0]->toReadable(),
[](const string &a, Expr b) {
return a + ',' + b->toReadable();
}) +
"]";
}
string serializeVec(vector<Var> v) {
VecExpr vv;
for (const auto &a : v)
vv.emplace_back(a);
return serializeVec(vv);
}
std::ostream &operator<<(std::ostream &ios, const ExprNode &expr) {
ios << expr.toReadable();
return ios;
}
TensorNode::TensorNode(string _name, vector<int> _shape, vector<int> _paddings,
Routine _source)
: name(_name), shape(_shape), paddings(_paddings), source(_source) {
if (source && source->getExpr()) {
if (auto range = as<RangeOpNode>(source->getExpr()))
for (auto [iter, lr] : range->getLoopVarRanges())
nnet_assert(lr.first == 0 && lr.second > 0,
"Tensor dims should start from 0.");
}
if (paddings.size() == 0)
paddings = vector<int>(shape.size(), 0);
assert(paddings.size() == shape.size());
}
string TensorNode::toOutputShape() const {
return "shape=" + serializeVec(shape) + " pad=" + serializeVec(paddings);
}
string TensorNode::toReadable() const {
string ret = name;
string property = "<pad=";
bool hasPaddings = false;
for (size_t i = 0; i < paddings.size(); ++i) {
if (i > 0)
property += ",";
property += to_string(paddings[i]);
if (paddings[i])
hasPaddings = true;
}
property += ">";
return (hasPaddings) ? ret + property : ret;
}
int TensorNode::getData(const Ref<vector<int>> &data, const vector<int> &idx) {
assert(idx.size() == shape.size());
for (size_t i = 0; i < idx.size(); ++i) {
if (idx[i] < 0 || idx[i] >= shape[i]) {
assert(0 - paddings[i] <= idx[i]);
assert(idx[i] < shape[i] + paddings[i]);
return 0;
}
}
return data->at(getOffset(idx));
}
size_t TensorNode::getOffset(const vector<int> &idx) {
auto nDim = idx.size();
assert(shape.size() == nDim);
if (idx.empty()) {
return 0;
}
for (size_t i = 0; i < nDim; i++) {
if (idx[i] < 0 || shape[i] <= idx[i]) {
return (size_t)-1;
}
}
size_t offset = idx[0];
size_t dim = 0;
while (++dim < nDim) {
offset = offset * shape[dim] + idx[dim];
}
return offset;
}
string RangeOpNode::toReadable() const {
string ret;
for (int i = 0; i < IterationType::NumIterationType; ++i) {
ret += (i == Loop) ? "L" : "Sum";
for (const auto &kv : vars[i]) {
ret += "<" + kv.first->getName() + ":" +
std::to_string(kv.second.first) + ":" +
std::to_string(kv.second.second) + ">";
}
if (i == Loop && hasPaddings()) {
ret += "<pad=";
for (const auto &i : paddings) {
ret += to_string(i) + ",";
}
ret += ">";
}
}
if (auto sub = as<SubscriptNode>(getSummand()); sub) {
ret += " ... " + serializeVec(sub->getIndex()) + "\n {" +
sub->getObject()->toReadable() + "}";
} else {
ret += "\n {" + subExprs[Summand]->toReadable() + "}";
}
return ret;
};
int RangeOpNode::getNumOutputDims() const { return vars[Loop].size(); }
bool RangeOpNode::hasVar(int index, Var name) const {
for (const auto &kv : vars[index])
if (kv.first->equal(name))
return true;
return false;
}
int RangeOpNode::getVarIndex(int type, string name) {
for (size_t i = 0; i < vars[type].size(); ++i)
if (vars[type][i].first->equal(name))
return i;
assert(0);
return 0;
}
Range RangeOpNode::getRange(const Var &var) const {
for (const auto &varRanges : vars) {
for (const auto &varRange : varRanges) {
if (varRange.first->equal(var))
return varRange.second;
}
}
nnet_assert(0, "Var is not a iterator.");
return Range();
}
VarRangePair RangeOpNode::getVarRange(const Var &var) const {
for (const auto &varRanges : vars) {
for (const auto &varRange : varRanges) {
if (varRange.first->equal(var))
return varRange;
}
}
nnet_assert(0, "Var is not a iterator.");
return VarRangePair();
}
void SubscriptNode::setObject(Expr e) {
nnet_assert(as<TensorNode>(e) || as<RangeOpNode>(e),
"Illegal subscripted object");
indexed = e;
}
bool SubscriptNode::isRangeOpSubscripted() const {
return as<RangeOpNode>(indexed) != nullptr;
}
vector<Range> SubscriptNode::getObjectRangesWithoutPaddings() const {
if (isRangeOpSubscripted()) {
return as<RangeOpNode>(indexed)->getOutputRanges();
} else {
vector<Range> ret;
for (const auto &len : as<TensorNode>(indexed)->getShape())
ret.emplace_back(0, len);
return ret;
}
}
vector<Range> SubscriptNode::getObjectRangesWithPaddings() const {
vector<Range> ret;
if (isRangeOpSubscripted()) {
auto rangeOp = as<RangeOpNode>(indexed);
ret = rangeOp->getOutputRanges();
for (size_t i = 0; i < ret.size(); ++i) {
if (int pad = rangeOp->getPaddings(i)) {
ret[i].first -= pad;
ret[i].second += pad;
}
}
} else {
auto tensor = as<TensorNode>(indexed);
for (const auto &len : tensor->getShape())
ret.emplace_back(0, len);
for (int i = 0; i < tensor->getDims(); ++i) {
if (int pad = tensor->getPadding(i)) {
ret[i].first -= pad;
ret[i].second += pad;
}
}
}
return ret;
}
optional<pair<Iterator, int>> BinaryOpNode::getModDivParameter() const {
auto lhs = as<VarNode>(getLhs());
auto rhs = as<ConstantNode>(getRhs());
if (lhs == nullptr) {
return {};
}
if (lhs->getType() != NodeType::VarNodeType) {
nnet_unimplemented_halt();
}
if (rhs->getType() != NodeType::ConstantNodeType) {
nnet_unimplemented_halt();
}
assert(rhs != nullptr);
return pair(lhs, rhs->getValue());
}
pair<Expr, int> BinaryOpNode::getModDivExpr() const {
auto constant = as<ConstantNode>(getRhs());
assert(constant != nullptr);
return pair(getLhs(), constant->getValue());
}
string BinaryOpNode::toReadable() const {
string ret = "(";
ret += subExprs[LHS]->toReadable();
ret += " ";
ret += opSymbols[static_cast<std::underlying_type_t<OpType>>(opType)];
ret += " ";
ret += subExprs[RHS]->toReadable();
return ret + ")";
};
bool BinaryOpNode::isSwapable() const {
switch (getOpType()) {
case OpType::Add:
case OpType::Mul:
return true;
case OpType::Sub:
case OpType::Div:
case OpType::Mod:
return false;
default:
nnet_unimplemented_halt();
return false;
}
}
string SubscriptNode::toReadable() const {
string ret;
ret += "{";
ret += indexed->toReadable();
ret += "}[";
for (size_t i = 0; i < subExprs.size(); ++i) {
ret += subExprs[i]->toReadable();
if (i != subExprs.size() - 1)
ret += ", ";
else
ret += "]";
}
return ret;
};
string FuncNode::toReadable() const {
string ret;
if (funcType == FuncType::Relu)
ret += "Relu";
else if (funcType == FuncType::Tanh)
ret += "Tanh";
else
nnet_unimplemented_halt();
ret += "( ... " + serializeVec(object->getIndex()) + ")\n {" +
object->getObject()->toReadable() + "}";
return ret;
}
Expr operator+(const Expr &lhs, const Expr &rhs) {
if (lhs == nullptr && rhs == nullptr)
return nullptr;
else if (lhs == nullptr)
return rhs;
else if (rhs == nullptr)
return lhs;
else
return make_ref<BinaryOpNode>(OpType::Add, lhs, rhs);
}
BinaryOp operator-(const Expr &lhs, const Expr &rhs) {
return make_ref<BinaryOpNode>(OpType::Sub, lhs, rhs);
}
BinaryOp operator*(const Expr &lhs, const Expr &rhs) {
return make_ref<BinaryOpNode>(OpType::Mul, lhs, rhs);
}
BinaryOp operator/(const Expr &lhs, const Expr &rhs) {
return make_ref<BinaryOpNode>(OpType::Div, lhs, rhs);
}
BinaryOp operator%(const Expr &lhs, const Expr &rhs) {
return make_ref<BinaryOpNode>(OpType::Mod, lhs, rhs);
}
Expr operator+(const Expr &lhs, const int &rhs) {
if (lhs != nullptr && rhs != 0)
return make_ref<BinaryOpNode>(OpType::Add, lhs,
make_ref<ConstantNode>(rhs));
else if (lhs == nullptr)
return make_ref<ConstantNode>(rhs);
else
return lhs;
}
Expr operator+(const int &lhs, const Expr &rhs) { return rhs + lhs; }
Expr operator-(const Expr &lhs, const int &rhs) { return lhs + (-rhs); }
Expr operator-(const int &lhs, const Expr &rhs) {
if (rhs != nullptr)
return make_ref<BinaryOpNode>(OpType::Sub, make_ref<ConstantNode>(lhs),
rhs);
else
return make_ref<ConstantNode>(lhs);
}
Expr operator*(const Expr &lhs, const int &rhs) {
if (rhs == 1)
return lhs;
else
return make_ref<BinaryOpNode>(OpType::Mul, lhs,
make_ref<ConstantNode>(rhs));
}
Expr operator*(const int &lhs, const Expr &rhs) {
if (lhs == 1)
return rhs;
else
return make_ref<BinaryOpNode>(OpType::Mul, make_ref<ConstantNode>(lhs),
rhs);
}
bool operator==(const Var &lhs, const string &rhs) {
return lhs->getName() == rhs;
}
bool operator==(const string &lhs, const Var &rhs) { return rhs == lhs; }
Expr operator%(const Expr &lhs, const int rhs) {
return make_ref<BinaryOpNode>(OpType::Mod, lhs,
make_ref<ConstantNode>(rhs));
}
Expr operator/(const Expr &lhs, const int rhs) {
if (rhs == 1)
return lhs;
else
return make_ref<BinaryOpNode>(OpType::Div, lhs,
make_ref<ConstantNode>(rhs));
}
// Wrappers for type deduction
Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts) {
return make_ref<SubscriptNode>(tensor, subscripts);
}
RangeOp makeRangeOperator(const vector<VarRangePair> &_loopIters,
const vector<VarRangePair> &_sumIters, Expr _summand,
const vector<int> &paddings) {
return make_ref<RangeOpNode>(_loopIters, _sumIters, _summand, paddings);
}
// Wrappers for type deduction
Tensor makeTensor(const string &name, const vector<int> &shape,
const vector<int> &paddings, const Routine &source) {
if (paddings.size() == 0)
return make_ref<TensorNode>(name, shape,
vector<int>((int)shape.size(), 0), source);
else
return make_ref<TensorNode>(name, shape, paddings, source);
}
int64_t TensorNode::getSize() const {
int64_t size = 1;
for (auto len : shape)
size *= len;
return size;
}
int RangeOpNode::getPaddings(int dim) const {
return dim < (int)paddings.size() ? paddings[dim] : 0;
}
vector<int> RangeOpNode::getPaddings() const {
if (paddings.size() > 0)
return paddings;
else
return vector<int>(getNumOutputDims(), 0);
}
void RangeOpNode::setPaddings(vector<int> _paddings) { paddings = _paddings; }
bool RangeOpNode::hasPaddings() const {
for (const auto &p : paddings)
if (p > 0)
return true;
return false;
}
int64_t RangeOpNode::getFlops() const {
int64_t sumCnt = getOutputSize();
if (vars[Sum].empty())
sumCnt = 0;
else
for (const auto &[var, range] : getSumVarRanges())
sumCnt *= range.second - range.first;
return sumCnt;
}
int64_t RangeOpNode::getInputSize(const RangeOp &self) const {
int64_t ret = 0;
auto tensors = GetTensorsVisitor().get(self);
for (const auto &[name, tensor] : tensors)
ret += tensor->getSize();
return ret;
}
int64_t RangeOpNode::getOutputSize() const {
int64_t loopCnt = 1;
for (const auto &[var, range] : getLoopVarRanges())
loopCnt *= range.second - range.first;
return loopCnt;
}
vector<int> RangeOpNode::getOutputShape() const {
vector<int> ret;
for (const auto &[var, range] : getLoopVarRanges())
ret.emplace_back(range.second - range.first);
return ret;
}
vector<Range> RangeOpNode::getOutputRanges() const {
vector<Range> ret;
for (const auto &[var, range] : getLoopVarRanges())
ret.emplace_back(range);
for (size_t i = 0; i < paddings.size(); ++i)
if (paddings[i] > 0) {
ret[i].first -= paddings[i];
ret[i].second += paddings[i];
}
return ret;
}
void FuncNode::setObject(Expr e) {
object = as<SubscriptNode>(e);
nnet_assert(object, "Illegal subscripted object");
}
} // namespace nnet

653
src/nnet/iterator_table.cc Normal file
View File

@ -0,0 +1,653 @@
#include "nnet/iterator_table.h"
#include "nnet/Visitor/MatchTableVisitor.h"
#include "nnet/Visitor/SimplifyExprVisitor.h"
#include "nnet/permutation.h"
#include <iostream>
namespace nnet {
bool IteratorTable::analyzeExpr(const RangeOp &_rangeOp) {
rangeOp = _rangeOp;
MatchTableVisitor mtVisitor;
if (!mtVisitor(rangeOp))
return false;
tie(appearance, tensors, strideTable, subscripts) = mtVisitor.getResult();
// dbg(appearance, tensors, strideTable);
return true;
}
// mapTensors
void IteratorTable::buildTable(const vector<int> &_tensorMap) {
tensorMap = _tensorMap;
tensorMap_r.clear();
tensorMap_r.resize(getNumInputs());
tensorIDMap_r.clear();
tensorIDMap_r.resize(getNumInputs());
posTable.clear();
posTable.resize(getNumRows());
strideInTensor.clear();
// build reversed index (from tensorID to tensor/original tensor index)
for (size_t i = 0; i < tensorMap.size(); ++i) {
tensorMap_r[tensorMap[i]] = tensors[i];
tensorIDMap_r[tensorMap[i]] = i;
}
strideInDim.clear();
strideInDim.resize(getNumInputs());
for (int i = 0; i < getNumInputs(); ++i)
strideInDim[i].resize(getTensor(i)->getDims());
// auxiliary array for calculate in-dim stride
vector<vector<int>> ldaInTensors(getNumInputs());
for (int i = 0; i < getNumInputs(); ++i) {
ldaInTensors[i].resize(getTensor(i)->getDims());
ldaInTensors[i].back() = 1;
for (int j = getTensor(i)->getDims() - 2; j >= 0; --j)
ldaInTensors[i][j] =
ldaInTensors[i][j + 1] * getTensor(i)->getShape(j + 1);
}
map<TensorNode *, int> inputTensor2id;
for (int i = 0; i < getNumInputs(); ++i)
inputTensor2id[tensors[i].get()] = tensorMap[i];
iterInTensorDim.clear();
iterInTensorDim.resize(getNumInputs());
for (int i = 0; i < getNumInputs(); ++i) {
iterInTensorDim[tensorMap[i]].resize(tensors[i]->getDims());
}
for (const auto &[var, tds] : strideTable) {
int bitmap = 0;
strideInTensor.emplace(var, getNumInputs());
for (const auto &[tensorNode, dim, stride] : tds) {
int tensorID = inputTensor2id[tensorNode];
int bit = 1 << tensorID;
if (std::find_if(iterInTensorDim[tensorID][dim].begin(),
iterInTensorDim[tensorID][dim].end(),
[var = var](const Var &v) {
return v->equal(var);
}) == iterInTensorDim[tensorID][dim].end())
iterInTensorDim[tensorID][dim].emplace_back(var);
bitmap |= bit;
if (strideInTensor[var][tensorID] < 0 || stride < 0)
strideInTensor[var][tensorID] = -1;
else
strideInTensor[var][tensorID] += stride;
// Update strideInDim
assert(stride % ldaInTensors[tensorID][dim] == 0);
strideInDim[tensorID][dim][var] =
stride / ldaInTensors[tensorID][dim];
}
if (rangeOp->hasLoopVar(var))
bitmap |= 1 << getNumInputs();
posTable[bitmap].emplace_back(var);
}
}
void IteratorTable::buildTableWithDefaultMap() {
vector<int> tensorMap;
for (int i = 0; i < getNumInputs(); ++i)
tensorMap.emplace_back(i);
buildTable(tensorMap);
}
int IteratorTable::getStridesInTensor(Iterator iter, int tensorID) const {
return strideInTensor.at(iter).at(tensorID);
}
vector<int> IteratorTable::getIterDimInTensor(int tensorID,
const Iterator &iter) const {
vector<int> ret;
for (size_t i = 0; i < iterInTensorDim[tensorID].size(); ++i) {
for (const auto &it : iterInTensorDim[tensorID][i])
if (iter->equal(it))
ret.emplace_back(i);
}
return ret;
}
vector<Mismatch> IteratorTable::matchPatternIT(const Pattern &patternIT) {
vector<Mismatch> ret;
iterMap.clear();
vector<vector<Iterator>> multiExprVar, multiPatternVar;
// match iterators in single iterator rows
for (int row = 0; row < getNumRows(); ++row) {
int nExprVars = posTable[row].size(),
nPatternVars = patternIT.posTable[row].size();
if (nExprVars < nPatternVars) {
ret.emplace_back(MismatchType::LessVar, row);
continue;
}
if (nExprVars > nPatternVars) {
ret.emplace_back(MismatchType::MoreVar, row);
continue;
}
if (posTable[row].empty())
continue;
// prepare for many-to-many iterator mapping
if (posTable[row].size() > 1) {
multiExprVar.emplace_back(posTable[row]);
multiPatternVar.emplace_back(patternIT.posTable[row]);
}
assert(!iterMap.count(posTable[row][0])); // check NO duplicate mapping
if (posTable[row].size() == 1)
iterMap[posTable[row][0]] = patternIT.posTable[row][0];
}
if (!ret.empty())
return ret;
PermutationGenerator permutationGenerator{multiPatternVar, multiExprVar};
bool checked = false;
// Permute iterator mappings to find a matched case
do {
auto mappingIter_r = permutationGenerator.get();
for (const auto &[exprIter, patternIter] : iterMap)
mappingIter_r[patternIter] = exprIter;
auto mismatches = matchPatternITCheckStrides(patternIT, mappingIter_r);
// if (mappingIter_r.count("_Conv_c"))
// if (mappingIter_r["_Conv_n"] == "n" &&
// mappingIter_r["_Conv_c"] == "c" &&
// mappingIter_r["_Conv_h"] == "i22" &&
// mappingIter_r["_Conv_r"] == "i4" &&
// mappingIter_r["_Conv_w"] == "i17" &&
// mappingIter_r["_Conv_s"] == "i14") {
// dbg(ret.size());
// if (mismatches.size() > 0)
// dbg(mismatches.size(), mismatches[0].type);
// }
if (mismatches.size() == 0) { // matched
ret = mismatches;
// Complete iterator mapping
for (const auto &[patternIter, exprIter] : mappingIter_r) {
if (iterMap.count(exprIter))
assert(iterMap[exprIter]->equal(patternIter));
iterMap[exprIter] = patternIter;
}
break;
} else if (!checked) {
ret = mismatches;
checked = true;
} else if ((static_cast<int>(ret[0].type) <
static_cast<int>(mismatches[0].type)) ||
((static_cast<int>(ret[0].type) ==
static_cast<int>(mismatches[0].type)) &&
(mismatches.size() < ret.size()))) {
ret = mismatches;
}
} while (permutationGenerator.next());
// Build reverse iterator mapping
if (ret.empty()) {
iterToRange_r.clear();
for (const auto &[exprIter, patternIter] : iterMap)
iterToRange_r[patternIter] = rangeOp->getVarRange(exprIter);
}
return ret;
}
vector<Mismatch> IteratorTable::matchPatternITCheckStrides(
const Pattern &patternIT, PtrMap<Iterator, Iterator> mappingIter_r) {
vector<Mismatch> ret;
// Check strides against each stride constraint
for (const auto &constraint : patternIT.getStrideConstraints()) {
// TODO: supprot PROPOTIONAL constraint
auto stride0 = strideInTensor.at(
mappingIter_r[constraint.v0])[constraint.tensorID];
auto stride1 = strideInTensor.at(
mappingIter_r[constraint.v1])[constraint.tensorID];
if (stride0 != stride1) {
ret.emplace_back(Mismatch(MismatchType::StrideMismatch, -1));
}
}
if (!ret.empty())
return ret;
// check the appearance of iterators inside tensors.
// If mismatch, this can be repaired by guided DLT.
for (int tensorID = 0; tensorID < getNumInputs(); ++tensorID) {
int exprTensorDim = tensorMap_r[tensorID]->getDims();
int patternTensorDim = patternIT.tensorMap_r[tensorID]->getDims();
if (exprTensorDim != patternTensorDim) {
ret.emplace_back(MismatchType::DLMismatch, tensorID);
continue;
}
[&] {
for (int dim = 0; dim < exprTensorDim; ++dim) {
// If #iters is differnt, than DLT is required
if (strideInDim[tensorID][dim].size() !=
patternIT.strideInDim[tensorID][dim].size()) {
ret.emplace_back(MismatchType::DLMismatch, tensorID);
return;
}
for (const auto &[patternIter, patternStride] :
patternIT.strideInDim[tensorID][dim]) {
auto exprIter = mappingIter_r[patternIter];
// If iters are differnt
if (!strideInDim[tensorID][dim].count(exprIter)) {
ret.emplace_back(MismatchType::DLMismatch, tensorID);
return;
}
auto exprStride = strideInDim[tensorID][dim].at(exprIter);
// TODO: for stride and dilation
if (exprStride != patternStride) {
ret.emplace_back(MismatchType::DLMismatch, tensorID);
return;
}
}
}
}();
}
if (!ret.empty())
return ret;
// check output data layout
// Output dim mismatch is not implemented.
if (patternIT.rangeOp->getNumOutputDims() != rangeOp->getNumOutputDims()) {
ret.emplace_back(Mismatch{MismatchType::OutputDimismatch, 0});
return ret;
}
for (size_t i = 0; i < rangeOp->getLoopVarRanges().size(); ++i) {
if (!mappingIter_r[patternIT.rangeOp->getLoopVar(i)]->equal(
rangeOp->getLoopVar(i))) {
ret.emplace_back(MismatchType::OutputDLMismatch, getNumInputs(),
mappingIter_r);
break;
}
}
return ret;
}
pair<PatternTensorMap, PatternIterRangeMap>
IteratorTable::getReverseMap() const {
return {tensorMap_r, iterToRange_r};
}
bool Pattern::isAllUniqueAccess(int tensorID) const {
for (const auto &iterInDim : iterInTensorDim[tensorID]) {
if (iterInDim.size() != 1)
return false;
}
return true;
}
Expr MatmulPattern::buildExpr(
const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges, string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const {
// TODO support b
assert(tensors.size() == 2);
int b = 1;
int m = tensors[0]->getShape(0), n = tensors[1]->getShape(0);
int k = tensors[0]->getShape(1);
// TODO: check strides
// TODO: DLT for output?
// FIXME: check the trans
auto matmul = make_ref<MatmulNode>(expr, tensors[0], tensors[1], b, m, n, k,
false, true);
auto output = make_ref<TensorNode>(outputName, vector<int>{m, n},
vector<int>{0, 0}, matmul);
return output;
}
const Pattern &MatmulPattern::getMatmulPattern() {
static class MatmulPattern exprIT;
static bool inited = false;
if (!inited) {
inited = true;
int M = 224, N = 8, K = 16;
auto m = make_ref<VarNode>("_Matmul_m");
auto n = make_ref<VarNode>("_Matmul_n");
auto k = make_ref<VarNode>("_Matmul_k");
auto A = make_ref<TensorNode>("_Matmul_A", vector<int>({M, K}));
auto B = make_ref<TensorNode>("_Matmul_B", vector<int>({N, K}));
auto subA = makeSubscript(A, {m, k});
auto subB = makeSubscript(B, {n, k});
auto range = makeRangeOperator({{m, {0, M}}, {n, {0, N}}},
{{k, {0, K}}}, subA * subB);
auto success = exprIT.analyzeExpr(range);
assert(success);
exprIT.buildTable({0, 1});
}
return exprIT;
}
const Pattern &ConvPattern::getPattern() {
static class ConvPattern exprIT;
static bool inited = false;
if (!inited) {
inited = true;
// The shape is meaningless but cannot be zero IT building
int N = 8, C = 16, H = 224, W = 224, F = 16, R = 3, S = 3;
// auto n = make_ref<VarNode>("_Matmul_n");
auto A = make_ref<TensorNode>("_Conv_A", vector<int>({N, C, H, W}));
auto B = make_ref<TensorNode>("_Conv_K", vector<int>({F, C, R, S}));
auto subA = makeSubscript(A, {n, c, h + r, w + s});
auto subB = makeSubscript(B, {f, c, r, s});
auto range = makeRangeOperator(
{{n, {0, 0}}, {f, {0, 0}}, {h, {0, 0}}, {w, {0, 0}}},
{{c, {0, 0}}, {r, {0, 0}}, {s, {0, 0}}}, subA * subB);
auto success = exprIT.analyzeExpr(range);
assert(success);
exprIT.buildTable({0, 1});
}
return exprIT;
}
Expr ConvPattern::buildExpr(
const Expr &expr, const vector<Tensor> &tensors,
const PatternIterRangeMap &varRanges, string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const {
// calculate paddings
const auto &rangeH = varRanges.at(h).second;
const auto &rangeR = varRanges.at(r).second;
const auto &rangeW = varRanges.at(w).second;
const auto &rangeS = varRanges.at(s).second;
auto offsetH =
SimplifyExprVisitor().getConstant(exprIT.getSubscript(0)->getIndex(2));
auto offsetW =
SimplifyExprVisitor().getConstant(exprIT.getSubscript(0)->getIndex(3));
int ph = calcPadding(tensors[0], 2, rangeH, rangeR, offsetH);
int pw = calcPadding(tensors[0], 3, rangeW, rangeS, offsetW);
// TODO strided, dilated
auto conv = make_ref<ConvNode>(expr, tensors[0], tensors[1], ph, pw);
auto shape = conv->getShape();
auto rangeOpShape = as<RangeOpNode>(expr)->getOutputShape();
assert(shape.size() == rangeOpShape.size());
dbg(shape, rangeOpShape);
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != rangeOpShape[i]) {
dbg("Warning: unmatched Conv output", shape, rangeOpShape);
return nullptr;
}
}
auto output =
make_ref<TensorNode>(outputName, shape, vector<int>{0, 0, 0, 0}, conv);
return output;
}
RangeOp IteratorTable::getRangeOp() const { return rangeOp; }
#define DEF_CONV_VAR(a) \
const Var ConvPattern::a = make_ref<VarNode>("_Conv_" #a)
DEF_CONV_VAR(n);
DEF_CONV_VAR(c);
DEF_CONV_VAR(h);
DEF_CONV_VAR(w);
DEF_CONV_VAR(f);
DEF_CONV_VAR(r);
DEF_CONV_VAR(s);
#undef DEF_CONV_VAR
int Pattern::calcPadding(const Tensor &tensor, int dim, Range rangeH,
Range rangeR, int offset) const {
int l = rangeH.first + rangeR.first + offset;
int r = rangeH.second + rangeR.second - 1 + offset;
int ret = max(0, max(0 - l, r - tensor->getShape(dim)));
// dbg(l, r, rangeH, rangeR, offset, ret, tensor->getPadding(dim));
// check OutOfBound
assert(ret <= tensor->getPadding(dim));
return ret;
}
#define DEF_SG2BMM_VAR(a) \
const Var Sg2bmmPattern::a = make_ref<VarNode>("_Sg2bmm_" #a)
DEF_SG2BMM_VAR(b);
DEF_SG2BMM_VAR(m);
DEF_SG2BMM_VAR(w);
DEF_SG2BMM_VAR(k);
#undef DEF_SG2BMM_VAR
const Pattern &Sg2bmmPattern::getPattern() {
static class Sg2bmmPattern exprIT;
static bool inited = false;
if (!inited) {
inited = true;
// The shape is meaningless but cannot be zero IT building
int Batch = 8, M = 32, K = 224, W = 2;
// auto n = make_ref<VarNode>("_Matmul_n");
auto A = make_ref<TensorNode>("_Sg2bmm_A", vector<int>{Batch, M, K});
auto B = make_ref<TensorNode>("_Sg2bmm_B", vector<int>{Batch, M, K});
auto subA = makeSubscript(A, {b, m, k});
auto subB = makeSubscript(B, {b, m + w, k});
auto range =
makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {-W, W + 1}}},
{{k, {0, K}}}, subA * subB);
auto success = exprIT.analyzeExpr(range);
assert(success);
exprIT.buildTableWithDefaultMap();
}
return exprIT;
}
Expr Sg2bmmPattern::buildExpr(
const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges, string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const {
// calculate paddings
assert(tensors.size() == 2);
assert(tensors[0]->getDims() == 3 && tensors[1]->getDims() == 3);
int Batch = tensors[0]->getShape(0);
int M = tensors[0]->getShape(1);
int K = tensors[0]->getShape(2);
int W = getLength(varRanges.at(w).second) / 2;
auto op = make_ref<G2bmmNode>(expr, tensors[0], tensors[1], Batch, M, W, K);
auto shape = op->getShape();
auto rangeOpShape = as<RangeOpNode>(expr)->getOutputShape();
assert(shape.size() == rangeOpShape.size());
for (size_t i = 0; i < shape.size(); ++i) {
assert(shape[i] == rangeOpShape[i]);
}
auto output =
make_ref<TensorNode>(outputName, shape, vector<int>{0, 0, 0}, op);
return output;
}
#define DEF_LongformerGBMM_VAR(a) \
const Var LongformerGBMMPattern::a = make_ref<VarNode>("_lo_" #a)
DEF_LongformerGBMM_VAR(b);
DEF_LongformerGBMM_VAR(m);
DEF_LongformerGBMM_VAR(w);
DEF_LongformerGBMM_VAR(n);
#undef DEF_LongformerGBMM_VAR
const Pattern &LongformerGBMMPattern::getPattern() {
static class LongformerGBMMPattern exprIT;
static bool inited = false;
if (!inited) {
inited = true;
// The shape is meaningless but cannot be zero IT building
int Batch = 8, M = 32, N = 224, W = 2;
auto A =
make_ref<TensorNode>("_lo_A", vector<int>{Batch, M, 2 * W + 1});
auto B = make_ref<TensorNode>("_lo_B", vector<int>{Batch, M, N});
auto subA = makeSubscript(A, {b, m, w});
auto subB = makeSubscript(B, {b, m + w, n});
auto range =
makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, M}}},
{{w, {-W, W + 1}}}, subA * subB);
auto success = exprIT.analyzeExpr(range);
assert(success);
exprIT.buildTableWithDefaultMap();
}
return exprIT;
}
Expr LongformerGBMMPattern::buildExpr(
const Expr &expr, const vector<Tensor> &tensors,
[[maybe_unused]] const PatternIterRangeMap &varRanges, string outputName,
[[maybe_unused]] const IteratorTable &exprIT) const {
// calculate paddings
assert(tensors.size() == 2);
assert(tensors[0]->getDims() == 3 && tensors[1]->getDims() == 3);
int Batch = tensors[0]->getShape(0);
int M = tensors[0]->getShape(1);
assert(tensors[0]->getShape(2) % 2 == 1);
int W = tensors[0]->getShape(2) / 2;
int N = tensors[1]->getShape(2);
auto op = make_ref<GbmmNode>(expr, tensors[0], tensors[1], Batch, M, W, N);
auto shape = op->getShape();
auto rangeOpShape = as<RangeOpNode>(expr)->getOutputShape();
assert(shape.size() == rangeOpShape.size());
for (size_t i = 0; i < shape.size(); ++i) {
assert(shape[i] == rangeOpShape[i]);
}
auto output = make_ref<TensorNode>(outputName, shape,
vector<int>(shape.size(), 0), op);
return output;
}
const Pattern &getPattern(RoutineType targetOp) {
switch (targetOp) {
case RoutineType::MatmulNodeType:
return MatmulPattern::getMatmulPattern();
case RoutineType::ConvNodeType:
return ConvPattern::getPattern();
case RoutineType::G2bmmNodeType:
return Sg2bmmPattern::getPattern();
case RoutineType::GbmmNodeType:
return LongformerGBMMPattern::getPattern();
default:
nnet_unimplemented_halt();
}
}
string getPatternName(RoutineType targetOp) {
switch (targetOp) {
case RoutineType::MatmulNodeType:
return "Matmul";
case RoutineType::ConvNodeType:
return "Conv";
case RoutineType::G2bmmNodeType:
return "G2bmm";
case RoutineType::GbmmNodeType:
return "Gbmm";
default:
nnet_unimplemented_halt();
}
return {};
}
#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
Expr ConvPattern::getExpr(Tensor A, Tensor K, int N, int C, int H, int W, int F,
int R, int S) {
DEFINE_VAR(n);
DEFINE_VAR(c);
DEFINE_VAR(h);
DEFINE_VAR(w);
DEFINE_VAR(f);
DEFINE_VAR(r);
DEFINE_VAR(s);
auto subA = makeSubscript(A, {n, c, h + r - R / 2, w + s - S / 2});
auto subB = makeSubscript(K, {f, c, r, s});
auto range =
makeRangeOperator({{n, {0, N}}, {f, {0, F}}, {h, {0, H}}, {w, {0, W}}},
{{c, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA * subB);
return range;
}
// Warn: F is the number of input channels, which is inversed compared with
// normal Conv.
// Input / output layouts: NHWF -> NHWC, Kenrel: RSFC
Expr ConvTransPattern::getExpr(Tensor A, Tensor K, int N, int C, int H, int W,
int F, int R, int S) {
const int padding = 1 * (R - 1) - 1;
assert(A->getPadding(2) == padding);
assert(R == 4);
assert(S == 4);
const int OH = 2 * H, OW = 2 * W;
DEFINE_VAR(n);
DEFINE_VAR(c);
DEFINE_VAR(f);
DEFINE_VAR(r);
DEFINE_VAR(s);
DEFINE_VAR(x1);
DEFINE_VAR(x2);
DEFINE_VAR(y1);
DEFINE_VAR(y2);
DEFINE_VAR(i2);
DEFINE_VAR(i4);
DEFINE_VAR(h);
DEFINE_VAR(w);
// dilation * (kernel_size - 1) - padding
// auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
// vector<int>{0, padding, padding, 0});
// auto K = make_ref<TensorNode>("K", vector<int>({R, S, F, C}));
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK =
makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c});
// x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2
auto range1 = makeRangeOperator(
{
{n, {0, N}},
{c, {0, C}},
{x1, {0, OH / 2 + 1}},
{x2, {0, 2}},
{y1, {0, OW / 2 + 1}},
{y2, {0, 2}},
},
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
auto sub0 = makeSubscript(
range1, {n, c, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2});
auto range0 = makeRangeOperator(
{{n, {0, N}}, {h, {0, OH}}, {w, {0, OW}}, {c, {0, C}}}, {}, sub0);
return range0;
}
pair<Expr, pair<Tensor, Tensor>> Sg2bmmPattern::getExpr(int Batch, int M, int K,
int W, int D) {
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
vector<int>{0, D * W, 0});
auto subA = makeSubscript(A, {b, m, k});
auto subB = makeSubscript(B, {b, m + D * (w - W), k});
auto range =
makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {0, 2 * W + 1}}},
{{k, {0, K}}}, subA * subB);
return {range, {A, B}};
}
pair<Expr, pair<Tensor, Tensor>>
LongformerGBMMPattern::getExpr(int Batch, int M, int W, int K, int dilation) {
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, 2 * W + 1}),
vector<int>{0, 0, 0});
auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
vector<int>{0, dilation * W, 0});
auto subA = makeSubscript(A, {b, m, w});
auto subB = makeSubscript(B, {b, m + dilation * w - dilation * W, n});
auto range = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, K}}},
{{w, {0, 2 * W + 1}}}, subA * subB);
return {range, {A, B}};
}
pair<Expr, pair<Tensor, Tensor>> MatmulPattern::getExpr(bool transA,
bool transB, int Batch,
int M, int N, int K) {
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(n);
DEFINE_VAR(k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
auto B = make_ref<TensorNode>("B", vector<int>({Batch, K, N}),
vector<int>{0, 0, 0});
auto subA = makeSubscript(A, {b, m, k});
auto subB = makeSubscript(B, {b, k, n});
auto range = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, N}}},
{{k, {0, K}}}, subA * subB);
return {range, {A, B}};
}
#undef DEFINE_VAR
} // namespace nnet

722
src/nnet/nmutator.cc Normal file
View File

@ -0,0 +1,722 @@
#include "nnet/nmutator.h"
#include "nnet/Visitor/FullPrinterVisitor.h"
#include "nnet/Visitor/GetTensorsVisitor.h"
#include "nnet/Visitor/MatchReshapeVisitor.h"
#include "nnet/derivator.h"
#ifdef ABC
namespace infini {
NMutator::NMutator() {}
NMutator::NMutator(const std::vector<int> &derivationRules)
: mode{Mode::RuleBased}, derivationRules{derivationRules} {}
NMutator::~NMutator() {}
void NMutator::setToNaiveMembound() { mode = Mode::ToNaiveMembound; }
vector<Graph> NMutator::run(const Graph &in_graph) {
vector<Graph> out_graphs;
// Test helper: naively transform one Op to Membound
if (mode == Mode::ToNaiveMembound) {
runSingleOpToNaiveMembound(in_graph, out_graphs);
dbg(out_graphs.size());
return;
}
// // Hack for HetConv fusion
// if (statGraph(in_graph) == NMutator::SGType::HetConv) {
// dbg("Start fuse HetConv");
// out_graphs.emplace_back(fuseHetConv(nullptr, in_graph));
// }
// Clear input names maps with tensor
inputsNameNToTensorT.clear();
std::vector<Operator> computeOps;
in_graph->getComputeOps(computeOps);
// assert(computeOps.size() == 1);
if (computeOps.size() == 1)
runSingleOp(in_graph, out_graphs);
// FIXME: runMultipleOps results in segfault
// else
// runMultipleOps(in_graph, out_graphs);
}
void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
std::vector<Graph> &out_graphs) {
std::vector<Operator> computeOps;
in_graph->getComputeOps(computeOps);
assert(computeOps.size() == 1);
const auto &computeOp = computeOps[0];
auto g = std::make_shared<infini::Graph>();
auto expr = opToExpression(computeOp);
auto inputsN = nnet::GetTensorsVisitor().get(expr);
dbg(inputsN);
g->membound(computeOp->getInputs(), computeOp->getOutputs(),
{inputsN.at("A"), inputsN.at("K")}, expr, 0);
auto subgraph = new Graph(g->getOperators());
subgraph->print();
out_graphs.emplace_back(subgraph);
}
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
std::vector<Operator> computeOps;
in_graph->getComputeOps(computeOps);
if (infini::Graph g = transformTConv1x1(computeOps[0])) {
Graph graph = new Graph(g->getOperators());
out_graphs.emplace_back(graph);
return;
}
// Commented for debug, not implemented yet
// if (infini::Graph g = transformTConv3x3(computeOps[0])) {
// Graph graph = new Graph(g->getOperators());
// out_graphs.emplace_back(graph);
// return;
// }
if (infini::Graph g = transformDialtedConv(computeOps[0])) {
Graph graph = new Graph(g->getOperators());
out_graphs.emplace_back(graph);
return;
}
// if (infini::Graph g = transformConv1x1(computeOps[0])) {
// Graph graph = new Graph(g->getOperators());
// out_graphs.emplace_back(graph);
// return;
// }
// if (infini::Graph g = transformConv1xk(computeOps[0])) {
// Graph graph = new Graph(g->getOperators());
// out_graphs.emplace_back(graph);
// return;
// }
auto expr = opToExpression(computeOps[0]);
if (!expr)
return;
nnet::Derivator derivator(maxDepth);
nnet::Formula conv_9x9(expr, 0);
// const std::vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}; // Tconv
// const std::vector<int> rules{1, 7, 7, 2, 8, 6, 6}; // G2BMM
if (mode == Mode::Normal) {
derivator.search(conv_9x9, 0);
} else if (mode == Mode::RuleBased) {
dbg(derivationRules);
derivator.ruleBasedDFS(conv_9x9, 0, derivationRules);
} else
nnet_assert(0, "Unknown mode");
const auto &candidates = derivator.getCandidates();
dbg(candidates.size());
// derivator.print();
for (const auto &candidate : candidates) {
// dbg(nnet::FullPrinterVisitor().print(candidate.root));
if (auto g = expressionToGraph(candidate.root, in_graph)) {
Graph graph = new Graph(g->getOperators());
out_graphs.emplace_back(graph);
}
// break; // HACK:Debug only for the first subgraph
}
// dbg(out_graphs);
// for (auto graph : out_graphs) {
// graph->print();
// }
cntStates += derivator.getNumIntermediateStates();
cntCandidates += derivator.getNumCandidates();
}
void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {
std::cout << "run multiple ops" << std::endl;
in_graph->print();
std::cout << std::endl;
std::vector<Operator> computeOps;
dbg(computeOps);
in_graph->getComputeOps(computeOps);
nnet::VecExpr exprs;
for (const auto &op : computeOps)
exprs.emplace_back(opToExpression(op));
dbg(exprs);
nnet::Derivator derivator;
nnet::MultiFormulas origin(exprs, 0);
bool canCombine = derivator.stageCombination(origin, 0);
dbg(canCombine);
const auto matmul0 = dynamic_cast<MatmulOp *>(computeOps[0]);
assert(matmul0);
// Build merged graph
auto g = new infini::Graph();
std::vector<Tensor *> inputsT, weightsT, outputsT;
for (const auto &opT : computeOps) {
inputsT.emplace_back(opT->getInputs(0));
weightsT.emplace_back(opT->getInputs(1));
outputsT.emplace_back(opT->getOutput());
}
const auto concat1 = g->concat(inputsT, 0);
const auto concat2 = g->concat(weightsT, 0);
const auto matmul = g->matmul(concat1->getOutput(), concat2->getOutput(),
matmul0->getTransA(), matmul0->getTransB());
g->split(matmul->getOutput(), outputsT, 0, computeOps.size());
// Build computation graph in PET:
g->updateConnection();
Graph graph = new Graph(g->getOperators());
out_graphs.emplace_back(graph);
// DEBUG
dbg(out_graphs);
for (auto graph : out_graphs) {
graph->print();
}
}
NMutator::SGType NMutator::statGraph(Graph sg) {
auto ops = sg->getOperators();
switch (ops.size()) {
case 0: {
return Empty;
break;
}
case 1: {
if (ops[0]->getType() == Operator::Conv) {
auto weight = ops[0]->getInputs()[1];
auto r = weight->getDims()[2];
auto s = weight->getDims()[3];
if (((ConvOp *)sg->getOperators()[0])->getDh() == 1 &&
((ConvOp *)sg->getOperators()[0])->getDw() == 1 && r == 1 &&
s == 1) {
return Conv1X1;
} else if (((ConvOp *)sg->getOperators()[0])->getDh() == 2 ||
((ConvOp *)sg->getOperators()[0])->getDw() == 2) {
return DilatedConv;
} else {
const Dim &inDim = ops[0]->getInputs()[0]->getDims();
const Dim &wDim = ops[0]->getInputs()[1]->getDims();
if (inDim[2] % 2 == 1 && inDim[3] % 2 == 1)
return NormalOddConv;
else if (wDim[2] != wDim[3])
return TransKernelConv;
else
return NormalConv;
}
} else if (ops[0]->getType() == Operator::Matmul) {
return NormalMatmul;
}
break;
}
default:
auto ty = ops[0]->getType();
for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i) {
if (ops[i]->getType() != ty)
return Others;
}
if (ty == Operator::Conv) {
std::vector<ConvOp *> convs;
for (auto op : ops)
convs.emplace_back(dynamic_cast<ConvOp *>(op));
// TODO: 1x1 conv enlarge. 1x1 conv has 0 padding
for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i)
if (!convs[i]->same(*convs[0]))
return Others;
auto inDim = ops[0]->getInputs(0)->getDims();
// TODO: enlarge input tensor?
for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i)
if (ops[i]->getInputs(0)->getDims() != inDim)
return Others;
auto weightDim = ops[0]->getInputs(1)->getDims();
auto groupFlag = true;
// TODO: kernel enlarge to group?
for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i) {
auto wDim = ops[i]->getInputs(1)->getDims();
if (!(wDim[1] == weightDim[1] && wDim[2] == weightDim[2] &&
wDim[3] == weightDim[3] && wDim[2] == wDim[3])) {
groupFlag = false;
break;
}
}
if (groupFlag)
return GroupConv;
// Hack for HetConv
if (ops.size() == 2) {
auto w1Dim = weightDim, w2Dim = ops[1]->getInputs(1)->getDims();
auto hConvFlag = false;
if (w1Dim[0] == w2Dim[0] && w1Dim[1] == w2Dim[1]) {
if (w1Dim[2] == 3 && w1Dim[3] == 3 && w2Dim[2] == 1 &&
w2Dim[3] == 1) {
hConvFlag = true;
}
if (w1Dim[2] == 1 && w1Dim[3] == 1 && w2Dim[2] == 3 &&
w2Dim[3] == 3) {
hConvFlag = true;
}
}
if (hConvFlag) {
// std::cout << "[nmutator stat graph]Het Conv found!"
// << std::endl;
// ops[0]->print();
// std::cout << std::endl;
// ops[1]->print();
// std::cout << std::endl;
return HetConv;
}
}
auto transGroupFlag = true;
// TODO: transpose group conv with different f dim?
for (size_t i = 1, iEnd = ops.size(); i < iEnd; ++i) {
auto wDim = ops[i]->getInputs(1)->getDims();
if (!(wDim[0] == weightDim[0] && wDim[1] == weightDim[1] &&
((wDim[2] == weightDim[2] && wDim[3] == weightDim[3]) ||
(wDim[2] == weightDim[3] && wDim[3] == weightDim[2])))) {
transGroupFlag = false;
break;
}
}
if (transGroupFlag)
return TransposeGroupConv;
} else if (ty == Operator::Matmul) {
// check same input shape or not
for (int i = 0; i < (int)ops.size() - 1; ++i) {
assert(dynamic_cast<MatmulOp *>(ops[i])->getTransA() ==
dynamic_cast<MatmulOp *>(ops[i + 1])->getTransA());
assert(dynamic_cast<MatmulOp *>(ops[i])->getTransB() ==
dynamic_cast<MatmulOp *>(ops[i + 1])->getTransB());
if (ops[i]->getInputs()[0]->getDims() !=
ops[i + 1]->getInputs()[0]->getDims()) {
return Others;
}
if (ops[i]->getInputs()[1]->getDims() !=
ops[i + 1]->getInputs()[1]->getDims()) {
return Others;
}
}
return BatchMatmul;
}
// TODO: others?
break;
}
return Others;
}
uint64_t NMutator::computeHashForSingleComputeOp(const Operator op) {
if (op->getType() == Operator::Conv) {
auto conv = dynamic_cast<const ConvOp *>(op);
auto hash = conv->getHash();
auto inputDim = conv->getInputs()[0]->getDims();
auto weightDim = conv->getOutputs()[0]->getDims();
hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
inputDim[2] * 10000103 + inputDim[3] * 10000121 +
weightDim[0] * 10000139 + weightDim[1] * 10000141 +
weightDim[2] * 10000169 + weightDim[3] * 10000189;
return hash;
} else if (op->getType() == Operator::ConvTrans) {
auto conv = dynamic_cast<const ConvTransOp *>(op);
auto hash = conv->getHash();
auto inputDim = conv->getInputs()[0]->getDims();
auto weightDim = conv->getOutputs()[0]->getDims();
hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
inputDim[2] * 10000103 + inputDim[3] * 10000121 +
weightDim[0] * 10000139 + weightDim[1] * 10000141 +
weightDim[2] * 10000169 + weightDim[3] * 10000189;
return hash;
} else if (op->getType() == Operator::Matmul) {
static uint64_t matmulhash = 0;
return matmulhash++;
} else if (op->getType() == Operator::G2BMM) {
auto g2bmm = dynamic_cast<const G2BMMOp *>(op);
auto hash = g2bmm->getHash();
auto inputDim = g2bmm->getInputs()[0]->getDims();
auto weightDim = g2bmm->getOutputs()[0]->getDims();
hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
inputDim[2] * 10000103 + inputDim[3] * 10000121 +
weightDim[0] * 10000139 + weightDim[1] * 10000141 +
weightDim[2] * 10000169 + weightDim[3] * 10000189;
return hash;
} else if (op->getType() == Operator::GBMML) {
auto gbmml = dynamic_cast<const GBMMLOp *>(op);
auto hash = gbmml->getHash();
auto inputDim = gbmml->getInputs()[0]->getDims();
auto weightDim = gbmml->getOutputs()[0]->getDims();
hash += inputDim[0] * 10000019 + inputDim[1] * 10000079 +
inputDim[2] * 10000103 + inputDim[3] * 10000121 +
weightDim[0] * 10000139 + weightDim[1] * 10000141 +
weightDim[2] * 10000169 + weightDim[3] * 10000189;
return hash;
} else {
// Not impl
assert(false);
return 0;
}
}
nnet::Expr NMutator::opToExpression(Operator op) {
if (auto convOp = dynamic_cast<ConvOp *>(op)) {
const auto &inputs = convOp->getInputs();
const auto &AT = inputs[0];
const auto &KT = inputs[1];
const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi, ac] =
convOp->getArgs(0);
dbg(n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw);
if (!(sh == 1 && sw == 1 && dh == 1 && dw == 1))
return nullptr;
assert(sh == 1 && sw == 1 && dh == 1 && dw == 1);
inputsNameNToTensorT["A"] = AT;
inputsNameNToTensorT["K"] = KT;
const auto A = nnet::makeTensor("A", AT->getDims(),
std::vector<int>{0, 0, ph, pw});
const auto K = nnet::makeTensor("K", KT->getDims());
return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s);
} else if (auto convOp = dynamic_cast<ConvTransOp *>(op)) {
const auto &AT = convOp->getInputs()[0];
const auto &KT = convOp->getInputs()[1];
inputsNameNToTensorT["A"] = AT;
inputsNameNToTensorT["K"] = KT;
const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi, ac] =
convOp->getArgs(0);
if (r != 4) {
dbg("ConvTranspose R!=4. Skipped.", r);
return nullptr;
}
int padding = 1 * (r - 1) - 1;
const auto A = nnet::makeTensor(
"A", AT->getDims(), std::vector<int>{0, padding, padding, 0});
const auto K = nnet::makeTensor("K", KT->getDims());
return nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s);
} else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(op)) {
const auto &AT = g2bmmOp->getInputs()[0];
const auto &BT = g2bmmOp->getInputs()[1];
const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();
const auto &[expr, inputsN] =
nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
inputsNameNToTensorT[inputsN.first->getName()] = AT;
inputsNameNToTensorT[inputsN.second->getName()] = BT;
return expr;
} else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(op)) {
const auto &AT = gbmmlOp->getInputs()[0];
const auto &BT = gbmmlOp->getInputs()[1];
const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
const auto &[expr, inputsN] =
nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
inputsNameNToTensorT[inputsN.first->getName()] = AT;
inputsNameNToTensorT[inputsN.second->getName()] = BT;
dbg(b, m, w, k, dilation, expr);
return expr;
} else if (auto matmulOp = dynamic_cast<MatmulOp *>(op)) {
const auto &AT = matmulOp->getInputs()[0];
const auto &BT = matmulOp->getInputs()[1];
const auto [transA, transB, b, m, n, k] = matmulOp->getArgs();
const auto &[expr, inputsN] =
nnet::MatmulPattern::getExpr(transA, transB, b, m, n, k);
inputsNameNToTensorT[inputsN.first->getName()] = AT;
inputsNameNToTensorT[inputsN.second->getName()] = BT;
dbg(b, m, n, k, expr);
return expr;
}
// else if (auto transposeOp = dynamic_cast<TransposeOp *>(op)) {
// return transposeOpToExpression(transposeOp);
// }
nnet_unimplemented_continue();
return nullptr;
}
infini::Graph NMutator::fuseHetConv(nnet::Expr expr, Graph in_graph) {
// Conv3x3+Conv1x1 => Gemm(nhw, f(rs+1), c) + Reduce
auto g = std::make_shared<infini::Graph>();
in_graph->print();
assert(in_graph->getInputs().size() == 3);
auto input = in_graph->getOperators()[0]->getInputs(0);
auto conv = dynamic_cast<ConvOp *>(in_graph->getOperators()[0]);
auto output = conv->getOutput();
// auto input = g->reshape(input);
auto inputTrans = g->transpose(input, 0, {-1, {0, 2, 3}, 1}, -1);
// dbg(inputTrans->getOutput()->getDims());
const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, G, bi, ac] =
conv->getArgs(0);
auto weight = g->tensor({1, c, f * (3 * 3 + 1)});
dbg(weight->getDims());
auto matmul = g->matmul(inputTrans->getOutput(), weight, false, false);
auto bias = g->tensor({f});
const double size = n * f * h * w * (3 * 3 + 1) * 4;
// FIXME: add NNET tensors for verfication
auto membound =
g->membound({matmul->getOutput(), bias}, {output}, {}, nullptr,
memboundTime(size), "Reduce_conv3x3+1x1");
dbg(n, f, h, w);
dynamic_cast<MemBoundOp *>(membound)->setNFHW(n, f, h, w);
return new Graph(g->getOperators());
}
infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
auto g = new infini::Graph();
nnet::FullPrinterVisitor fullVisitor;
const auto &tensorQueueN = fullVisitor.traverse(expr);
// Build tensors: Skip the first one, which is output
auto nameNToTensorT = inputsNameNToTensorT;
for (size_t i = 1; i < tensorQueueN.size(); ++i) {
const auto &[nameN, routineN, tensorN] = tensorQueueN[i];
// dbg(nameN, routineN, tensorN);
if (!routineN) {
// This is an inputs
assert(nameNToTensorT.count(nameN));
} else {
assert(!nameNToTensorT.count(nameN));
nameNToTensorT[nameN] = g->tensor(tensorN->getShape());
}
}
const auto &outputsPET = in_graph->getOutputs();
if (outputsPET.size() != 1) {
nnet_unimplemented_continue();
return nullptr;
}
nameNToTensorT[std::get<0>(tensorQueueN.at(0))] = outputsPET[0];
// Build computation graph in PET:
for (int i = tensorQueueN.size() - 1; i >= 0; --i) {
const auto &[outputNameN, routineN, tensorN] = tensorQueueN[i];
if (!routineN)
continue;
// dbg(outputNameN, routineN, tensorN, routineN->getType());
if (auto op = nnet::as<nnet::ConvNode>(routineN)) {
// g->conv(i8, w9, 2, 2);
std::vector<nnet::Tensor> inputsN = op->getInputs();
auto A = nameNToTensorT.at(inputsN[0]->getName());
auto K = nameNToTensorT.at(inputsN[1]->getName());
auto output = nameNToTensorT.at(outputNameN);
const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
g->conv(A, K, output, ph, pw, sh, sw, dh, dw);
} else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
assert(op->getInputs().size() == 1);
nnet::MatchReshapeVisitor matchReshapeVisitor;
if (matchReshapeVisitor(op->getExpr())) {
auto input =
nameNToTensorT.at(op->getInputs().at(0)->getName());
auto output = nameNToTensorT.at(outputNameN);
g->reshape(input, output);
} else {
TensorVec inputsPET;
TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
for (const auto &inputN : op->getInputs())
inputsPET.emplace_back(
nameNToTensorT.at(inputN->getName()));
// Re-estimate time here.
ssize_t cnt = 0;
for (const auto tensor : inputsPET)
cnt += tensor->size();
for (const auto tensor : outputsPET)
cnt += tensor->size();
g->membound(inputsPET, outputsPET, op->getInputs(),
op->getExpr(), memboundTime(cnt));
}
} else if (auto op = nnet::as<nnet::MatmulNode>(routineN)) {
assert(op->getInputs().size() == 2);
nnet::Tensor AN = op->getInputs()[0];
nnet::Tensor BN = op->getInputs()[1];
TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
nameNToTensorT.at(BN->getName())};
TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
const auto &[b, m, n, k, transa, transb] = op->getArgs();
g->matmul(inputsPET[0], inputsPET[1], outputsPET[0], transa,
transb);
} else if (auto op = nnet::as<nnet::G2bmmNode>(routineN)) {
assert(op->getInputs().size() == 2);
nnet::Tensor AN = op->getInputs()[0];
nnet::Tensor BN = op->getInputs()[1];
TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
nameNToTensorT.at(BN->getName())};
TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
const auto &[b, m, w, k, dilation] = op->getArgs();
g->g2bmm(inputsPET[0], inputsPET[1], outputsPET[0], w, dilation);
} else if (auto op = nnet::as<nnet::GbmmNode>(routineN)) {
assert(op->getInputs().size() == 2);
nnet::Tensor AN = op->getInputs()[0];
nnet::Tensor BN = op->getInputs()[1];
TensorVec inputsPET = {nameNToTensorT.at(AN->getName()),
nameNToTensorT.at(BN->getName())};
TensorVec outputsPET = {nameNToTensorT.at(outputNameN)};
const auto &[b, m, w, n, dilation] = op->getArgs();
g->gbmml(inputsPET[0], inputsPET[1], outputsPET[0], dilation);
}
}
g->updateConnection();
Graph graph = new Graph(g->getOperators());
return graph;
}
Graph NMutator::transformDialtedConv(Operator op) {
if (auto convOp = dynamic_cast<ConvOp *>(op)) {
if (convOp->getPh() == convOp->getDh() && convOp->getSh() == 1 &&
convOp->getDh() > 1 && convOp->getDh() == convOp->getDw()) {
const int d = convOp->getDh();
assert(convOp->getInputs()[0]->getDims()[2] % d == 0);
auto g = new infini::Graph();
auto inputDims = convOp->getInputs(0)->getDims();
auto weightDims = convOp->getInputs(1)->getDims();
auto outputDims = convOp->getOutput()->getDims();
auto newA = g->tensor({inputDims[0] * d * d, inputDims[1],
inputDims[2] / d, inputDims[3] / d});
// auto newW = g->tensor(
// {weightDims[0] * weightDims[1] * weightDims[3],
// weightDims[2]});
auto newO =
g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
weightDims[0] * weightDims[1] * weightDims[3]});
g->membound(
{convOp->getInputs(0)}, {newA}, {}, nullptr,
memboundTime(convOp->getInputs(0)->size() + newA->size()),
"DConv Tranpose Input");
// g->membound({convOp->getInputs(1)}, {newW}, {}, nullptr, 0,
// "Tranpose Weight");
g->conv(newA, convOp->getInputs(1), newO, 1, 1, 1, 1);
g->reshape(newO, convOp->getOutput());
dbg(newO->size(), convOp->getOutput()->size());
assert(newO->size() == convOp->getOutput()->size());
g->membound(
{newO}, {convOp->getOutput()}, {}, nullptr,
memboundTime(newO->size() + convOp->getOutput()->size()),
"DConv Tranpose Output");
g->updateConnection();
Graph graph = new Graph(g->getOperators());
return graph;
}
}
return nullptr;
}
double NMutator::memboundTime(ssize_t cnt) {
return double(cnt) * 4 / bandwidth * 1000; // millisecond
}
double NMutator::memboundTime(const Dim &dims) {
return memboundTime(dims.size());
}
Graph NMutator::transformTConv3x3(Operator op) {
if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
dbg(tconvOp->getInputs()[1]->getDims());
if (tconvOp->getPh() == 1 && tconvOp->getSh() == 2 &&
tconvOp->getInputs()[1]->getDims()[0] == 3 &&
tconvOp->getInputs()[1]->getDims()[1] == 3) {
auto g = new infini::Graph();
auto inputDims = tconvOp->getInputs(0)->getDims();
auto weightDims = tconvOp->getInputs(1)->getDims();
auto outputDims = tconvOp->getOutput()->getDims();
// NHWF
auto newA = g->tensor(
{inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
// RSFC
auto newW = g->tensor(
{weightDims[0] * weightDims[1] * weightDims[3], weightDims[2]});
auto newO =
g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
weightDims[0] * weightDims[1] * weightDims[3]});
g->reshape(tconvOp->getInputs(0), newA);
g->reshape(tconvOp->getInputs(1), newW);
g->matmul(newA, newW, newO, 0, 1);
// g->reshape(newO, tconvOp->getOutput());
tconvOp->print();
dbg(newO->size() * 4, tconvOp->getOutput()->size() * 9);
assert(newO->size() * 4 == tconvOp->getOutput()->size() * 9);
g->membound(
{newO}, {tconvOp->getOutput()}, {}, nullptr,
memboundTime(newO->size() + tconvOp->getOutput()->size()),
"TConv3x3 reduce");
g->updateConnection();
Graph graph = new Graph(g->getOperators());
return graph;
}
}
return nullptr;
}
Graph NMutator::transformTConv1x1(Operator op) {
if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
if (tconvOp->getPh() == 0 && tconvOp->getSh() == 1) {
auto g = new infini::Graph();
auto inputDims = tconvOp->getInputs(0)->getDims();
auto weightDims = tconvOp->getInputs(1)->getDims();
auto outputDims = tconvOp->getOutput()->getDims();
auto newA = g->tensor(
{inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
auto newW = g->tensor(
{weightDims[0] * weightDims[1] * weightDims[3], weightDims[2]});
auto newO =
g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
weightDims[0] * weightDims[1] * weightDims[3]});
g->reshape(tconvOp->getInputs(0), newA);
g->reshape(tconvOp->getInputs(1), newW);
g->matmul(newA, newW, newO, 0, 1);
g->reshape(newO, tconvOp->getOutput());
g->updateConnection();
Graph graph = new Graph(g->getOperators());
return graph;
}
}
return nullptr;
}
Graph NMutator::transformConv1x1(Operator op) {
auto convOp = dynamic_cast<ConvOp *>(op);
if (!convOp)
return nullptr;
if (convOp->getPh() == 0 && convOp->getSh() == 1 &&
convOp->getInputs()[1]->getDims()[2] == 1 &&
convOp->getInputs()[1]->getDims()[3] == 1) {
// Transpose is requrired for BS>1
// if (convOp->getInputs()[0]->getDims()[0] == 1) {
auto g = new infini::Graph();
auto inputDims = convOp->getInputs(0)->getDims();
auto weightDims = convOp->getInputs(1)->getDims();
auto outputDims = convOp->getOutput()->getDims();
auto newA = g->tensor(
{inputDims[1], inputDims[0] * inputDims[2] * inputDims[3]});
auto newW = g->tensor({weightDims[0], weightDims[1]});
auto newO = g->tensor(
{weightDims[0], inputDims[0] * inputDims[2] * inputDims[3]});
g->reshape(convOp->getInputs(0), newA);
g->reshape(convOp->getInputs(1), newW);
g->matmul(newW, newA, newO, 0, 0);
g->reshape(newO, convOp->getOutput());
g->updateConnection();
Graph graph = new Graph(g->getOperators());
return graph;
}
return nullptr;
}
Graph NMutator::transformConv1xk(Operator op) {
auto convOp = dynamic_cast<ConvOp *>(op);
if (!convOp)
return nullptr;
if (convOp->getSh() != 1 || convOp->getSw() != 1)
return nullptr;
bool a = convOp->getInputs()[1]->getDims()[2] == 1;
bool b = convOp->getInputs()[1]->getDims()[3] == 1;
if (!(a ^ b))
return nullptr;
convOp->print();
auto g = new infini::Graph();
auto inputDims = convOp->getInputs(0)->getDims();
auto weightDims = convOp->getInputs(1)->getDims();
auto outputDims = convOp->getOutput()->getDims();
auto newA =
g->tensor({inputDims[0] * inputDims[2] * inputDims[3], inputDims[1]});
auto newW = g->tensor(
{weightDims[0] * weightDims[2] * weightDims[3], weightDims[1]});
auto newO = g->tensor({weightDims[0] * weightDims[2] * weightDims[3],
inputDims[0] * inputDims[2] * inputDims[3]});
// g->reshape(convOp->getInputs(0), newA);
g->membound({convOp->getInputs(0)}, {newA}, {}, nullptr,
memboundTime(convOp->getInputs(0)->size() + newA->size()),
"1xk input reshape");
g->reshape(convOp->getInputs(1), newW);
g->matmul(newW, newA, newO, 0, 1);
g->membound({newO}, {convOp->getOutput()}, {}, nullptr,
memboundTime(newW->size() + convOp->getOutput()->size()),
"1xk reduce");
g->updateConnection();
Graph graph = new Graph(g->getOperators());
return graph;
}
} // namespace infini
#endif

35
src/nnet/permutation.cc Normal file
View File

@ -0,0 +1,35 @@
#include "nnet/permutation.h"
#include <algorithm>
namespace nnet {
PermutationGenerator::PermutationGenerator(vector<vector<Iterator>> _from,
vector<vector<Iterator>> _to)
: from(_from), to(_to), mapping(from.size()) {
assert(from.size() == to.size());
for (size_t i = 0; i < from.size(); ++i)
for (size_t j = 0; j < from[i].size(); ++j)
mapping[i].emplace_back(j);
}
bool PermutationGenerator::next() {
if (mapping.empty())
return false;
for (int i = (int)mapping.size() - 1; i >= 0; --i) {
if (std::next_permutation(mapping[i].begin(), mapping[i].end()))
return true;
}
return false;
}
PtrMap<Iterator, Iterator> PermutationGenerator::get() const {
if (mapping.empty())
return {};
PtrMap<Iterator, Iterator> ret;
for (size_t i = 0; i < mapping.size(); ++i)
for (size_t j = 0; j < mapping[i].size(); ++j)
ret[from[i][j]] = to[i][mapping[i][j]];
return ret;
}
} // namespace nnet

92
src/nnet/replace_kit.cc Normal file
View File

@ -0,0 +1,92 @@
#include "nnet/ReplaceKit.h"
#include "nnet/Visitor/ReplaceVariable.h"
#include "nnet/Visitor/SimplifyFormulaMutator.h"
namespace nnet {
RangeOp ReplaceKit::replaceRangeOpIterator(const RangeOp &rangeOp,
const Replace &replace,
const Expr &replacedSummand) {
vector<VarRangePair> newVarRangePairs(replace.newVarRanges);
if (replace.iteratorType == IterationType::Loop) {
for (const auto &[var, range] : rangeOp->getLoopVarRanges()) {
if (!replace.isReplaced(var))
newVarRangePairs.emplace_back(var, range);
}
assert(newVarRangePairs.size() == rangeOp->getLoopVarRanges().size() -
replace.oldIters.size() +
replace.newIters.size());
// Check the number of loop iterators
return makeRangeOperator(newVarRangePairs, rangeOp->getSumVarRanges(),
replacedSummand);
} else if (replace.iteratorType == IterationType::Sum) {
for (const auto &[var, range] : rangeOp->getSumVarRanges()) {
if (!replace.isReplaced(var))
newVarRangePairs.emplace_back(var, range);
}
assert(newVarRangePairs.size() == rangeOp->getSumVarRanges().size() -
replace.oldIters.size() +
replace.newIters.size());
return makeRangeOperator(rangeOp->getLoopVarRanges(), newVarRangePairs,
replacedSummand, rangeOp->getPaddings());
}
assert(false);
return nullptr;
}
Subscript ReplaceKit::buildSubscirptForLoopVarReplace(const RangeOp &inner,
const Replace &replace) {
VecExpr subs(replace.phis);
for (size_t i = 0; i < replace.newVarRanges.size(); ++i) {
assert(replace.newIters[i]->equal(inner->getLoopVar(i)));
}
for (size_t i = replace.newVarRanges.size();
i < inner->getLoopVarRanges().size(); ++i) {
subs.emplace_back(inner->getLoopVar(i));
}
// The support of var reorder and replace at the same time
// VecExpr subs;
// for (size_t i = 0; i < inner->getLoopVarRanges().size(); ++i) {
// if (auto it = std::find(replace.newIters.begin(),
// replace.newIters.end(),
// inner->getLoopVar(i));
// it != replace.newIters.end()) {
// subs.emplace_back(replace.phis[it - replace.newIters.begin()]);
// } else
// subs.emplace_back(inner->getLoopVar(i));
// }
return makeSubscript(inner, subs);
}
RangeOp
ReplaceKit::buildDLTOuterRangeOp(const RangeOp &original,
const Subscript &subscriptedNewRangeOp) {
auto outer = make_ref<RangeOpNode>(*original);
outer->setSummand(subscriptedNewRangeOp);
outer->setSumIterator({});
return outer;
}
Expr ReplaceKit::replaceMultipleExprs(const Expr &cur,
const vector<Var> &patterns,
const VecExpr &replacements,
bool simplify) {
auto ret = cur;
for (size_t i = 0; i < patterns.size(); ++i) {
ret = replaceExpr(ret, patterns[i], replacements[i]);
}
if (simplify) {
SimplifyFormulaMutator simplifyFormulaMutator;
ret = simplifyFormulaMutator.simplify(ret);
}
return ret;
}
Expr ReplaceKit::replaceExpr(const Expr &cur, const Expr &pattern,
const Expr &replacement) {
auto replace = ReplaceVariable(pattern, replacement);
auto ret = replace(cur);
return ret;
}
} // namespace nnet

136
src/nnet/routine.cc Normal file
View File

@ -0,0 +1,136 @@
#include "nnet/routine.h"
#include "nnet/Visitor/CloneMutator.h"
#include <algorithm>
namespace nnet {
RoutineNode::RoutineNode(Expr _expr, const vector<Tensor> &_inputs)
: inputs(_inputs) {
if (_expr)
expr = CloneMutator().clone(_expr);
else
expr = nullptr;
}
string MatmulNode::toReadable() const {
std::ostringstream oss;
assert(inputs.size() == 2);
oss << "Matmul{bmnk = " << b << ", " << m << ", " << n << ", " << k
<< "; AB = " << inputs[0] << ", " << inputs[1]
<< "; transAB = " << transa << ", " << transb << "}";
return oss.str();
}
bool operator==(const MatmulNode &lhs, const MatmulNode &rhs) {
if (!((lhs.b == rhs.b) && lhs.m == rhs.m && lhs.n == rhs.n &&
lhs.k == rhs.k && lhs.transa == rhs.transa &&
lhs.transb == rhs.transb))
return false;
if (lhs.inputs.size() != rhs.inputs.size())
return false;
for (size_t i = 0; i < lhs.inputs.size(); ++i) {
if (lhs.inputs[i]->getName() != rhs.inputs[i]->getName())
return false;
}
return true;
}
string ConvNode::toReadable() const {
std::ostringstream oss;
assert(inputs.size() == 2);
oss << "Conv{A =" << inputs[0]
<< " shape=" << serializeVec(inputs[0]->getShape())
<< ", K=" << inputs[1]
<< " shape=" << serializeVec(inputs[1]->getShape()) << ", p = " << ph
<< ", " << pw << ", s= " << sh << ", " << sw << ", d= " << dh << ", "
<< dw << "}";
return oss.str();
}
bool operator==(const ConvNode &lhs, const ConvNode &rhs) {
if (!(lhs.ph == rhs.ph && lhs.pw == rhs.pw && lhs.sh == rhs.sh &&
lhs.sw == rhs.sw && lhs.dh == rhs.dh && lhs.dw == rhs.dw))
return false;
if (lhs.inputs.size() != rhs.inputs.size())
return false;
for (size_t i = 0; i < lhs.inputs.size(); ++i) {
if (lhs.inputs[i]->getName() != rhs.inputs[i]->getName())
return false;
}
return true;
}
vector<int> ConvNode::getShape() const {
auto input = inputs[0], weight = inputs[1];
auto n = input->getShape(0);
auto h = input->getShape(2);
auto w = input->getShape(3);
auto f = weight->getShape(0);
auto r = weight->getShape(2);
auto s = weight->getShape(3);
int on = n, oc = f;
int oh = 0, ow = 0;
// Set padding size
oh = (h - (r - sh) * dh + ph * 2) / sh;
ow = (w - (s - sw) * dw + pw * 2) / sw;
auto ret = {on, oc, oh, ow};
return ret;
}
ConvArgs ConvNode::getArgs() const { return tuple(ph, pw, sh, sw, dh, dw); }
vector<int> G2bmmNode::getShape() const { return {b, m, 2 * w + 1}; }
vector<int> GbmmNode::getShape() const { return {b, m, n}; }
string ElementWiseNode::toReadable() const {
std::ostringstream oss;
oss << "EleWise{";
for (const auto &input : inputs)
oss << input << ", ";
oss << "}";
return oss.str();
}
double ElementWiseNode::getEstimatedTime() const {
int64_t cntElements = 0;
// For unimplemented transpose
assert(inputs.size() > 0);
if (!expr) {
assert(inputs.size() == 1);
}
for (const auto &input : inputs)
cntElements += input->getSize();
int64_t outputSize = 1;
for (const auto &len : outputShape)
outputSize *= len;
cntElements += outputSize;
const double bandwidth = 200 * 1000000;
// dbg(inputs, inputs[0]->getShape(), cntElements,
// (cntElements * 4) / bandwidth);
return double(cntElements * 4) / bandwidth; // ms
}
string G2bmmNode::toReadable() const {
std::ostringstream oss;
oss << "G2bmm{";
for (const auto &input : inputs)
oss << input << ", ";
oss << ", bmwk = " << b << " " << m << " " << w << " " << k << "}";
return oss.str();
}
string GbmmNode::toReadable() const {
std::ostringstream oss;
oss << "Gbmm{";
for (const auto &input : inputs)
oss << input << ", ";
oss << ", bmwn = " << b << " " << m << " " << w << " " << n << "}";
return oss.str();
}
G2bmmArgs G2bmmNode::getArgs() const { return {b, m, w, k, 1}; }
GbmmArgs GbmmNode::getArgs() const { return {b, m, w, n, 1}; }
} // namespace nnet

70
src/nnet/test.cc Normal file
View File

@ -0,0 +1,70 @@
#include "nnet/Visitor/FullPrinterVisitor.h"
#include "nnet/Visitor/GetTensorsVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/Interpreter.h"
#include "nnet/Visitor/Serializer.h"
namespace nnet {
int matchExprResult(Derivator &derivator, string fn) {
auto ans = Serializer().deserialize(fn);
auto hashAns = HashVisitor()(ans);
int match = 0;
for (const auto &candidate : derivator.getCandidates()) {
auto hashCandidate = HashVisitor()(candidate.root);
match += (hashAns == hashCandidate);
}
return match;
}
bool checkExprLogSame(string fnPrefix, int start, int end) {
Serializer serializer;
string fn0 = fnPrefix + to_string(start) + ".expr";
Expr expr0 = serializer.deserialize(fn0);
RangeOp range0 = as<RangeOpNode>(expr0);
Interpreter interpreter(range0);
auto ans0 = interpreter.interpretUniformSample(range0);
dbg(expr0, ans0);
for (int i = start + 1; i < end; ++i) {
string fn1 = fnPrefix + to_string(i) + ".expr";
Expr expr1 = serializer.deserialize(fn1);
RangeOp range1 = as<RangeOpNode>(expr1);
dbg(fn1, expr1);
auto ans1 = interpreter.interpretUniformSample(range1);
dbg(ans1);
if (ans0.size() != ans1.size())
return false;
for (size_t i = 0; i < ans0.size(); ++i)
if (ans0[i] != ans1[i])
return false;
}
return true;
}
bool checkExprsEquvivalence(VecExpr exprs) {
if (exprs.size() < 2)
return true;
auto inputsMap0 = GetTensorsVisitor().get(exprs[0]);
RangeOp range0 = as<RangeOpNode>(exprs[0]);
Interpreter interpreter(range0);
auto ans0 = interpreter.interpretUniformSample(range0);
for (size_t i = 1; i + 1 < exprs.size(); ++i) {
RangeOp range1 = as<RangeOpNode>(exprs[i]);
auto inputsMap1 = GetTensorsVisitor().get(range1);
// if expr0 and expr1 have different inputs, skip and return true
if (inputsMap0.size() != inputsMap1.size())
return true;
for (const auto &[name, tensor] : inputsMap0) {
if (!inputsMap1.count(name))
return true;
}
auto ans1 = interpreter.interpretUniformSample(range1);
if (ans0.size() != ans1.size())
return false;
for (size_t i = 0; i < ans0.size(); ++i)
if (ans0[i] != ans1[i])
return false;
}
return true;
}
} // namespace nnet

12
src/nnet/utils.cc Normal file
View File

@ -0,0 +1,12 @@
#include "nnet/common.h"
namespace nnet {
std::string pointer_to_hex(void *i) {
std::stringstream stream;
// stream << "0x" << std::setfill('0') << std::setw(sizeof(void *) * 2) <<
// std::hex
// << i;
stream << std::hex << i;
return stream.str();
}
} // namespace nnet

97
src/nnet/visitor.cc Normal file
View File

@ -0,0 +1,97 @@
#include "nnet/visitor.h"
namespace nnet {
Expr Mutator::visit_([[maybe_unused]] const Constant &c) { return nullptr; }
Expr Mutator::visit_(const BinaryOp &c) {
if (verbose)
dbg(*c);
bool modified = false;
auto ret = make_ref<BinaryOpNode>(*c);
if (auto e = this->dispatch(ret->getLhs()); e) {
modified = true;
ret->setLhs(e);
}
if (auto e = this->dispatch(ret->getRhs()); e) {
modified = true;
ret->setRhs(e);
}
return (modified) ? ret : nullptr;
}
Expr Mutator::visit_(const RangeOp &c) {
if (verbose)
dbg(*c);
bool modified = false;
auto ret = make_ref<RangeOpNode>(*c);
if (auto mutated = this->dispatch(ret->getSummand()); mutated) {
modified = true;
ret->setSummand(mutated);
}
// NOT visit iterators and its ranges
return (modified) ? ret : nullptr;
}
Expr Mutator::visit_(const Subscript &c) {
if (verbose)
dbg(*c);
bool modified = false;
auto ret = make_ref<SubscriptNode>(*c);
for (size_t i = 0; i < ret->getDims(); ++i) {
const auto &e = ret->getIndex(i);
if (const auto &mutated = this->dispatch(e); mutated) {
modified = true;
ret->setIndex(i, mutated);
}
}
if (auto mutated = this->dispatch(ret->getObject()); mutated) {
modified = true;
ret->setObject(mutated);
}
return (modified) ? ret : nullptr;
}
Expr Mutator::visit_([[maybe_unused]] const Var &c) { return nullptr; }
Expr Mutator::visit_([[maybe_unused]] const Tensor &c) { return nullptr; }
Expr Mutator::visit_(const Func &c) {
if (verbose)
dbg(c);
bool modified = false;
auto ret = make_ref<FuncNode>(*c);
if (auto mutated = dispatch(c->getObject())) {
modified = true;
ret->setObject(mutated);
}
return (modified) ? ret : nullptr;
}
void ExprTreeVisitor::visit_(const RangeOp &c) {
if (inRange)
dispatch(c->getSummand());
}
void ExprTreeVisitor::visit_(const BinaryOp &c) {
if (inBinary) {
dispatch(c->getLhs());
dispatch(c->getRhs());
}
}
void ExprTreeVisitor::visit_(const Subscript &c) {
if (inSub) {
dispatch(c->getObject());
for (const auto &index : c->getIndex())
dispatch(index);
}
}
void ExprTreeVisitor::visit_([[maybe_unused]] const Var &c) {}
void ExprTreeVisitor::visit_([[maybe_unused]] const Constant &c) {}
void ExprTreeVisitor::visit_(const Tensor &c) {
if (inTensor && c->getSource()) {
if (const auto &expr = c->getSource()->getExpr(); expr)
dispatch(expr);
}
}
void ExprTreeVisitor::visit_(const Func &c) { dispatch(c->getObject()); }
} // namespace nnet

View File

@ -0,0 +1,49 @@
#include "code_engine.h"
#include "nnet/expr.h"
#include "nnet/nmutator.h"
#include "operator.h"
#include "search_engine.h"
#include "tensor.h"
#include "gtest/gtest.h"
using namespace nnet;
using namespace std;
#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
TEST(Activation, Relu) {
const int n_heads = 8, seq_len = 10000, feat_len = 512;
// dilation_heads = 2;
const int Batch = n_heads, M = seq_len, K = feat_len, W = 32;
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
auto subA = makeSubscript(A, {b, m, k});
auto innerRange = makeRangeOperator(
{{b, {0, Batch}}, {m, {0, M}}, {k, {0, K}}}, {}, subA);
auto outerSub = makeSubscript(innerRange, {b, m, k});
// auto subB = makeSubscript(B, {b, m + dilation * (w - W), k});
auto relu = make_ref<FuncNode>(subA, FuncType::Relu);
auto range =
makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {0, 2 * W + 1}}},
{{k, {0, K}}}, relu);
dbg(range);
auto g = new tpm::Graph();
auto i0 = g->tensor({Batch, M, K});
auto i1 = g->tensor({Batch, M, 2 * W + 1});
tpm::TensorVec inputsT{i0};
tpm::TensorVec outputsT{i1};
g->membound(inputsT, outputsT, {A}, range, 0);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>());
searchEngine.run(graph, bestGraph);
tpm::CodeEngine codeEngine;
codeEngine.genCode(bestGraph, "res.cu");
}

Some files were not shown because too many files have changed in this diff Show More