Add: Figure 17

This commit is contained in:
Liyan Zheng 2023-04-24 21:07:30 +08:00
parent 1e46750159
commit 11229a2baa
2 changed files with 68 additions and 23 deletions

View File

@ -136,9 +136,10 @@ void Derivator::dfs(Formula &origin, int depth) {
}
Derivator::Derivator(int maxDepth, bool enableHashPruning, LogMode logMode,
PassMode passMode)
PassMode passMode, bool printAndExit)
: maxDepth(maxDepth), logMode(logMode), passMode(passMode),
enableHashPruning(enableHashPruning), cntAppliedRules(12) {}
enableHashPruning(enableHashPruning), cntAppliedRules(12),
printAndExit(printAndExit) {}
// Returns the current value of the cntStates counter.
int Derivator::getNumIntermediateStates() { return cntStates; }
@ -405,6 +406,8 @@ Expr Derivator::mergeMemboundStages(VecExpr stages) {
void Derivator::appendCanddiate(const Tensor &tensor, int depth) {
// if (!CountRoutineVisitor().match(tensor, 1, 0, 3))
// return;
if (intermediateStates.size() > 1 && printAndExit)
printDerivationRules();
candidates.emplace_back(tensor, depth);
// dbg("!!!!!!!!!!!!!!!Success!!!!!!!!!!!!!!!");
@ -478,6 +481,7 @@ void Derivator::printStatistics() {
printf("#Hashed intermediate states = %lu\n", visited.size());
printf("#Iteratos = %d\n", nIteratorNames);
printf("#Tensors = %d\n", nTensorNames);
printf("#Print and Exit mode = %d\n", printAndExit);
}
void Derivator::setDumpFirstSuccess(const string &_logFnPrefix) {
@ -490,6 +494,9 @@ void Derivator::printIntermediateStates() {
// Skip in NoLog mode
if (logMode == LogMode::NoLog)
return;
if (intermediateStates.size() > 1 && printAndExit)
printDerivationRules();
assert(intermediateStates.size() == ruleStates.size());
assert(intermediateStates.size() == ruleMsgs.size());
for (size_t i = 0; i < intermediateStates.size(); ++i) {
@ -503,12 +510,13 @@ void Derivator::printIntermediateStates() {
logFnPrefix + to_string(i) + ".expr", msg);
}
}
for (size_t i = 0; i < intermediateStates.size(); ++i) {
if (auto cur = as<RangeOpNode>(intermediateStates[i]))
if (CheckOOBVisitor().checkRangeOp(cur)) {
printf("OOB detected depth=%lu\n", i);
}
}
// FIXME: the OOB check below is commented out; re-enable it or remove it.
// for (size_t i = 0; i < intermediateStates.size(); ++i) {
// if (auto cur = as<RangeOpNode>(intermediateStates[i]))
// if (CheckOOBVisitor().checkRangeOp(cur)) {
// printf("OOB detected depth=%lu\n", i);
// }
// }
if (logMode == LogMode::DumpFristCandiate) {
puts("Serializaiton finished.");
exit(0);
@ -528,4 +536,21 @@ Derivator::PassMode Derivator::getPassMode() { return passMode; }
// Returns the log mode this derivator was constructed with.
Derivator::LogMode Derivator::getLogMode() { return logMode; }
void Derivator::printDerivationRules() {
int cntRules = 0, cntNonGuideRules = 0;
bool startGuided = false;
std::cout << ruleStates.size() << "rules" << std::endl;
for (size_t i = 1; i < ruleStates.size(); ++i) {
int ruleId = ruleStates[i][4] - '0';
if (ruleId != 4)
++cntRules;
if (ruleId == 8)
startGuided = true;
if (!startGuided && ruleId != 4)
++cntNonGuideRules;
}
printf("#Steps w/o converging derivation %d, #Steps w/ converging derivation %d\n", cntRules, cntNonGuideRules);
exit(0);
}
} // namespace nnet

View File

@ -6,25 +6,35 @@
using namespace nnet;
using namespace std;
// clang-format off
/* Evaluation bash script
# Maxdepth
for i in $(seq 1 12); do
echo $i
NNET_UseHash=1 NNET_MaxDepth=$i ./test_OpSearch # &>
out.searchDepthTest.$i.txt done
NNET_UseHash=1 NNET_MaxDepth=$i ./test_OpSearch &> out.searchDepthTest.$i.txt done
NNET_UseHash=0 NNET_MaxDepth=8 ./test_OpSearch # &> out.searchDepthTest.$i.txt
NNET_UseHash=1 NNET_MaxDepth=8 ./test_OpSearch # &> out.searchDepthTest.$i.txt
# Enable/disable hash
NNET_UseHash=0 NNET_MaxDepth=8 ./test_OpSearch &> out.searchDepthTest.$i.txt
NNET_UseHash=1 NNET_MaxDepth=8 ./test_OpSearch &> out.searchDepthTest.$i.txt
# Number of derivation steps
for i in Conv3x3 ConvTranspose Conv5x5 G2BMM; do
NNET_PrintAndExit=1 NNET_UseHash=1 NNET_MaxDepth=7 ./test_OpSearch --gtest_filter="*$i" > out.steps.$i.txt done
*/
// clang-format on
class OpSearch : public ::testing::Test {
protected:
const int maxDepth = getMaxDepth();
const int useHash = getUseHash();
const bool printAndExit = getPrintAndExit();
// const int maxDepth = 8;
// const int useHash = true;
const Derivator::LogMode mode = Derivator::LogMode::NoLog;
const Derivator::LogMode mode = getPrintAndExit()
? Derivator::LogMode::DumpFristCandiate
: Derivator::LogMode::NoLog;
const Derivator::PassMode passMode = Derivator::PassMode::Full;
const bool isRuleBased = false;
const bool isRuleBased = getPrintAndExit();
void SetUp() override {
if (maxDepth < 0 || useHash < 0) {
@ -45,9 +55,16 @@ class OpSearch : public ::testing::Test {
return atoi(s);
return -1;
}
// Reads the NNET_PrintAndExit environment variable.
// Returns false when the variable is unset; otherwise returns whether its
// atoi() value is non-zero (so non-numeric text yields false).
static bool getPrintAndExit() {
    const char *raw = getenv("NNET_PrintAndExit");
    if (raw == nullptr)
        return false;
    return atoi(raw) != 0;
}
};
TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
// TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
TEST_F(OpSearch, Conv3x3) {
// A[n,h+r,w+s,c]*K[r,s,f,c]
int N = 1, H = 7, W = 7, C = 512, F = 512;
int R = 3, S = 3;
@ -72,7 +89,7 @@ TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
// Derivation
Formula conv_3x3_nhwc_rsfc(range, 0);
Derivator derivator(maxDepth, useHash, mode, passMode);
Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
if (isRuleBased) {
// Rule-based derivation
@ -142,13 +159,14 @@ RangeOp buildTConv4x4_NHWF_RSFC(const int N, const int C, const int H,
return range0;
}
TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
// TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
TEST_F(OpSearch, ConvTranspose) {
const int N = 16, H = 2, W = 2, C = 256, F = 448;
const int R = 4, S = 4;
RangeOp range = buildTConv4x4_NHWF_RSFC(N, C, H, W, F, R, S);
Formula conv_9x9(range, 0);
Derivator derivator(maxDepth, useHash, mode, passMode);
Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
if (isRuleBased) {
const vector<int> rules{3, 2, 2, 2, 2, 5};
@ -165,7 +183,8 @@ TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
derivator.printStatistics();
}
TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
// TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
TEST_F(OpSearch, Conv5x5) {
int N = 16, C = 32, H = 224, W = 224, F = 1;
int R = 5, S = 5;
auto n = make_ref<VarNode>("n");
@ -187,7 +206,7 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
{{c, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA * subK);
Formula conv_9x9(range, 0);
Derivator derivator(maxDepth, useHash, mode, passMode);
Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
if (isRuleBased) {
const vector<int> rules{9, 1, 1, 3, 2, 2, 5, 8, 8, 6, 6};
@ -204,7 +223,8 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
EXPECT_GE(nMatches, 1);
}
TEST_F(OpSearch, G2BMM_RuleBased) {
// TEST_F(OpSearch, G2BMM_RuleBased) {
TEST_F(OpSearch, G2BMM) {
const int n_heads = 8, seq_len = 10000, feat_len = 512;
const int Batch = n_heads, M = seq_len, K = feat_len, W = 32, dilation = 4;
DEFINE_VAR(b);
@ -224,7 +244,7 @@ TEST_F(OpSearch, G2BMM_RuleBased) {
// Derivation: this works without the padding check in stage merging
Formula dialted_g2bmm(range, 0);
Derivator derivator(maxDepth, useHash, mode, passMode);
Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
if (isRuleBased) {
const vector<int> rules{1, 7, 7, 2, 8, 6, 6};
@ -240,4 +260,4 @@ TEST_F(OpSearch, G2BMM_RuleBased) {
EXPECT_GE(nMatches, 1);
// derivator.print();
derivator.printStatistics();
}
}