forked from jiuyuan/InfiniTensor
Add: Figure 17
parent 1e46750159
commit 11229a2baa
@@ -136,9 +136,10 @@ void Derivator::dfs(Formula &origin, int depth) {
 }
 
 Derivator::Derivator(int maxDepth, bool enableHashPruning, LogMode logMode,
-                     PassMode passMode)
+                     PassMode passMode, bool printAndExit)
     : maxDepth(maxDepth), logMode(logMode), passMode(passMode),
-      enableHashPruning(enableHashPruning), cntAppliedRules(12) {}
+      enableHashPruning(enableHashPruning), cntAppliedRules(12),
+      printAndExit(printAndExit) {}
 
 int Derivator::getNumIntermediateStates() { return cntStates; }
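A minimal usage sketch of the extended constructor, assuming the nnet/derivator.h header path and the Derivator::search(Formula &, int) entry point used by the OpSearch tests below; the argument values and the helper name demoPrintAndExit are illustrative, not part of this commit:

    #include "nnet/derivator.h" // assumed header location for nnet::Derivator / nnet::Formula

    // Build a derivator with printAndExit enabled: once a candidate with more
    // than one intermediate state is appended, printDerivationRules() (added in
    // the later hunks) prints the step counts and terminates the process.
    void demoPrintAndExit(nnet::Formula &formula) {
        nnet::Derivator derivator(/*maxDepth=*/7,
                                  /*enableHashPruning=*/true,
                                  nnet::Derivator::LogMode::DumpFristCandiate,
                                  nnet::Derivator::PassMode::Full,
                                  /*printAndExit=*/true);
        derivator.search(formula, 0); // search entry point as used in test_OpSearch (assumption)
    }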
@@ -405,6 +406,8 @@ Expr Derivator::mergeMemboundStages(VecExpr stages) {
 void Derivator::appendCanddiate(const Tensor &tensor, int depth) {
     // if (!CountRoutineVisitor().match(tensor, 1, 0, 3))
     //     return;
+    if (intermediateStates.size() > 1 && printAndExit)
+        printDerivationRules();
 
     candidates.emplace_back(tensor, depth);
     // dbg("!!!!!!!!!!!!!!!Success!!!!!!!!!!!!!!!");
@@ -478,6 +481,7 @@ void Derivator::printStatistics() {
     printf("#Hashed intermediate states = %lu\n", visited.size());
     printf("#Iteratos = %d\n", nIteratorNames);
     printf("#Tensors = %d\n", nTensorNames);
+    printf("#Print and Exit mode = %d\n", printAndExit);
 }
 
 void Derivator::setDumpFirstSuccess(const string &_logFnPrefix) {
@@ -490,6 +494,9 @@ void Derivator::printIntermediateStates() {
     // Skip in NoLog mode
     if (logMode == LogMode::NoLog)
         return;
+    if (intermediateStates.size() > 1 && printAndExit)
+        printDerivationRules();
+
     assert(intermediateStates.size() == ruleStates.size());
     assert(intermediateStates.size() == ruleMsgs.size());
     for (size_t i = 0; i < intermediateStates.size(); ++i) {
@@ -503,12 +510,13 @@ void Derivator::printIntermediateStates() {
                 logFnPrefix + to_string(i) + ".expr", msg);
         }
     }
-    for (size_t i = 0; i < intermediateStates.size(); ++i) {
-        if (auto cur = as<RangeOpNode>(intermediateStates[i]))
-            if (CheckOOBVisitor().checkRangeOp(cur)) {
-                printf("OOB detected depth=%lu\n", i);
-            }
-    }
+    // FIXME
+    // for (size_t i = 0; i < intermediateStates.size(); ++i) {
+    //     if (auto cur = as<RangeOpNode>(intermediateStates[i]))
+    //         if (CheckOOBVisitor().checkRangeOp(cur)) {
+    //             printf("OOB detected depth=%lu\n", i);
+    //         }
+    // }
     if (logMode == LogMode::DumpFristCandiate) {
         puts("Serializaiton finished.");
         exit(0);
@@ -528,4 +536,21 @@ Derivator::PassMode Derivator::getPassMode() { return passMode; }
 
 Derivator::LogMode Derivator::getLogMode() { return logMode; }
 
+void Derivator::printDerivationRules() {
+    int cntRules = 0, cntNonGuideRules = 0;
+    bool startGuided = false;
+    std::cout << ruleStates.size() << "rules" << std::endl;
+    for (size_t i = 1; i < ruleStates.size(); ++i) {
+        int ruleId = ruleStates[i][4] - '0';
+        if (ruleId != 4)
+            ++cntRules;
+        if (ruleId == 8)
+            startGuided = true;
+        if (!startGuided && ruleId != 4)
+            ++cntNonGuideRules;
+    }
+    printf("#Steps w/o converging derivation %d, #Steps w/ converging derivation %d\n", cntRules, cntNonGuideRules);
+    exit(0);
+}
+
 } // namespace nnet
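A reading note on printDerivationRules() above, with a small sketch of its parsing assumption: ruleStates[i][4] - '0' treats character 4 of each ruleStates entry as a single-digit rule id (a hypothetical entry format such as "Rule3 ..."), so rule ids above 9 would not decode correctly; rule-4 steps are excluded from both counters, and cntNonGuideRules only counts steps seen before the first rule-8 step.

    #include <string>

    // Hypothetical ruleStates entry "Rule3 <details>": the digit sits at index 4,
    // so parseRuleId returns 3. A multi-digit rule id would need a real parser.
    int parseRuleId(const std::string &ruleState) {
        return ruleState[4] - '0';
    }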
@@ -6,25 +6,35 @@
 using namespace nnet;
 using namespace std;
 
 // clang-format off
 /* Evaluation bash script
 # Maxdepth
 for i in $(seq 1 12); do
     echo $i
-    NNET_UseHash=1 NNET_MaxDepth=$i ./test_OpSearch # &>
-    out.searchDepthTest.$i.txt done
+    NNET_UseHash=1 NNET_MaxDepth=$i ./test_OpSearch &> out.searchDepthTest.$i.txt done
 
-NNET_UseHash=0 NNET_MaxDepth=8 ./test_OpSearch # &> out.searchDepthTest.$i.txt
-NNET_UseHash=1 NNET_MaxDepth=8 ./test_OpSearch # &> out.searchDepthTest.$i.txt
+# Enable/disable hash
+NNET_UseHash=0 NNET_MaxDepth=8 ./test_OpSearch &> out.searchDepthTest.$i.txt
+NNET_UseHash=1 NNET_MaxDepth=8 ./test_OpSearch &> out.searchDepthTest.$i.txt
+
+# Number of derivation steps
+for i in Conv3x3 ConvTranspose Conv5x5 G2BMM; do
+    NNET_PrintAndExit=1 NNET_UseHash=1 NNET_MaxDepth=7 ./test_OpSearch --gtest_filter="*$i" > out.steps.$i.txt done
 */
 // clang-format on
 
 class OpSearch : public ::testing::Test {
   protected:
     const int maxDepth = getMaxDepth();
     const int useHash = getUseHash();
+    const bool printAndExit = getPrintAndExit();
     // const int maxDepth = 8;
     // const int useHash = true;
-    const Derivator::LogMode mode = Derivator::LogMode::NoLog;
+    const Derivator::LogMode mode = getPrintAndExit()
+                                        ? Derivator::LogMode::DumpFristCandiate
+                                        : Derivator::LogMode::NoLog;
     const Derivator::PassMode passMode = Derivator::PassMode::Full;
-    const bool isRuleBased = false;
+    const bool isRuleBased = getPrintAndExit();
 
     void SetUp() override {
         if (maxDepth < 0 || useHash < 0) {
@@ -45,9 +55,16 @@ class OpSearch : public ::testing::Test {
             return atoi(s);
         return -1;
     }
 
+    static bool getPrintAndExit() {
+        if (auto s = getenv("NNET_PrintAndExit"))
+            return atoi(s);
+        return 0;
+    }
 };
 
-TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
+// TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
+TEST_F(OpSearch, Conv3x3) {
     // A[n,h+r,w+s,c]*K[r,s,f,c]
     int N = 1, H = 7, W = 7, C = 512, F = 512;
     int R = 3, S = 3;
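getPrintAndExit() follows the same getenv-plus-atoi pattern as the existing getMaxDepth()/getUseHash() helpers, so an unset or non-numeric NNET_PrintAndExit evaluates to false and the default test behaviour is unchanged. A standalone sketch of that behaviour (the helper name readPrintAndExit is illustrative):

    #include <cstdlib>

    // Mirrors getPrintAndExit() above: "1" (or any non-zero number) enables the
    // mode; an unset variable, "0", or a non-numeric value leaves it disabled.
    static bool readPrintAndExit() {
        if (const char *s = std::getenv("NNET_PrintAndExit"))
            return std::atoi(s) != 0;
        return false;
    }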
@@ -72,7 +89,7 @@ TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
 
     // Derivation
     Formula conv_3x3_nhwc_rsfc(range, 0);
-    Derivator derivator(maxDepth, useHash, mode, passMode);
+    Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
 
     if (isRuleBased) {
         // Rule-based derivation
@@ -142,13 +159,14 @@ RangeOp buildTConv4x4_NHWF_RSFC(const int N, const int C, const int H,
     return range0;
 }
 
-TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
+// TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
+TEST_F(OpSearch, ConvTranspose) {
     const int N = 16, H = 2, W = 2, C = 256, F = 448;
     const int R = 4, S = 4;
     RangeOp range = buildTConv4x4_NHWF_RSFC(N, C, H, W, F, R, S);
 
     Formula conv_9x9(range, 0);
-    Derivator derivator(maxDepth, useHash, mode, passMode);
+    Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
 
     if (isRuleBased) {
         const vector<int> rules{3, 2, 2, 2, 2, 5};
@@ -165,7 +183,8 @@ TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
     derivator.printStatistics();
 }
 
-TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
+// TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
+TEST_F(OpSearch, Conv5x5) {
     int N = 16, C = 32, H = 224, W = 224, F = 1;
     int R = 5, S = 5;
     auto n = make_ref<VarNode>("n");
@@ -187,7 +206,7 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
                          {{c, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA * subK);
 
     Formula conv_9x9(range, 0);
-    Derivator derivator(maxDepth, useHash, mode, passMode);
+    Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
 
     if (isRuleBased) {
         const vector<int> rules{9, 1, 1, 3, 2, 2, 5, 8, 8, 6, 6};
@@ -204,7 +223,8 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
     EXPECT_GE(nMatches, 1);
 }
 
-TEST_F(OpSearch, G2BMM_RuleBased) {
+// TEST_F(OpSearch, G2BMM_RuleBased) {
+TEST_F(OpSearch, G2BMM) {
     const int n_heads = 8, seq_len = 10000, feat_len = 512;
     const int Batch = n_heads, M = seq_len, K = feat_len, W = 32, dilation = 4;
     DEFINE_VAR(b);
@@ -224,7 +244,7 @@ TEST_F(OpSearch, G2BMM_RuleBased) {
 
     // Derivation: this work without padding check in stage merging
     Formula dialted_g2bmm(range, 0);
-    Derivator derivator(maxDepth, useHash, mode, passMode);
+    Derivator derivator(maxDepth, useHash, mode, passMode, printAndExit);
 
     if (isRuleBased) {
         const vector<int> rules{1, 7, 7, 2, 8, 6, 6};
@@ -240,4 +260,4 @@ TEST_F(OpSearch, G2BMM_RuleBased) {
     EXPECT_GE(nMatches, 1);
     // derivator.print();
     derivator.printStatistics();
-}
+}