forked from jiuyuan/InfiniTensor
243 lines
8.2 KiB
C++
243 lines
8.2 KiB
C++
#include "nnet/Visitor/CountRoutineVisitor.h"
|
|
#include "nnet/derivator.h"
|
|
#include "nnet/expr.h"
|
|
#include "nnet/test.h"
|
|
#include "gtest/gtest.h"
|
|
using namespace nnet;
|
|
using namespace std;
|
|
|
|
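/* Evaluation bash script.
   First sweep the maximum search depth with hashing enabled, then compare
   hashing off vs. on at a fixed depth of 8. The output redirections are
   commented out with `#`; remove the `#` to save each run's log to a file.

for i in $(seq 1 12); do
    echo $i
    NNET_UseHash=1 NNET_MaxDepth=$i ./test_OpSearch # &> out.searchDepthTest.$i.txt
done

NNET_UseHash=0 NNET_MaxDepth=8 ./test_OpSearch # &> out.searchDepthTest.noHash.8.txt
NNET_UseHash=1 NNET_MaxDepth=8 ./test_OpSearch # &> out.searchDepthTest.hash.8.txt
*/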
|
|
|
|
class OpSearch : public ::testing::Test {
|
|
protected:
|
|
const int maxDepth = getMaxDepth();
|
|
const int useHash = getUseHash();
|
|
// const int maxDepth = 8;
|
|
// const int useHash = true;
|
|
const Derivator::LogMode mode = Derivator::LogMode::NoLog;
|
|
const Derivator::PassMode passMode = Derivator::PassMode::Full;
|
|
const bool isRuleBased = false;
|
|
|
|
void SetUp() override {
|
|
if (maxDepth < 0 || useHash < 0) {
|
|
GTEST_SKIP() << "Skipping OpSearch since NNET_MaxDepth or "
|
|
"NNET_UseHash are not specifed.\n";
|
|
}
|
|
}
|
|
|
|
private:
|
|
static int getMaxDepth() {
|
|
if (auto s = getenv("NNET_MaxDepth"))
|
|
return atoi(s);
|
|
return -1;
|
|
}
|
|
|
|
static bool getUseHash() {
|
|
if (auto s = getenv("NNET_UseHash"))
|
|
return atoi(s);
|
|
return -1;
|
|
}
|
|
};
|
|
|
|
TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) {
    // 3x3 convolution expressed as
    //   out[n,h,w,f] = sum_{c,r,s} A[n, h+r-R/2, w+s-S/2, c] * K[r,s,f,c]
    // Despite the "NCHW_FCRS" in the test name, A is laid out NHWC and K RSFC.
    const int N = 1, H = 7, W = 7, C = 512, F = 512;
    const int R = 3, S = 3;
    // A is padded by half the kernel size on H and W ("same" padding).
    const int padH = R / 2, padW = S / 2;

    DEFINE_VAR(n);
    DEFINE_VAR(c);
    DEFINE_VAR(h);
    DEFINE_VAR(w);
    DEFINE_VAR(f);
    DEFINE_VAR(r);
    DEFINE_VAR(s);

    auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, C}),
                                  vector<int>{0, padH, padW, 0});
    auto K = make_ref<TensorNode>("K", vector<int>({R, S, F, C}));
    auto subA = makeSubscript(A, {n, h + r - padH, w + s - padW, c});
    auto subK = makeSubscript(K, {r, s, f, c});

    // Loop iterators: n, h, w, f; summation iterators: c, r, s.
    auto range = makeRangeOperator(
        {{n, {0, N}}, {h, {0, H}}, {w, {0, W}}, {f, {0, F}}},
        {{c, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA * subK);

    Formula conv_3x3_nhwc_rsfc(range, 0);
    Derivator derivator(maxDepth, useHash, mode, passMode);
    if (!isRuleBased) {
        derivator.search(conv_3x3_nhwc_rsfc, 0);
    } else {
        // Replay a fixed rule sequence; results are dumped under the given
        // prefix.
        const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
        derivator.setDumpFirstSuccess("Conv2gemm_NCHW_FCRS_");
        derivator.ruleBasedDFS(conv_3x3_nhwc_rsfc, 0, rules);
    }

    EXPECT_GE(derivator.getNumCandidates(), 1);
    // At least one derived candidate must match the stored reference.
    const int nMatches = matchExprResult(
        derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr");
    EXPECT_GE(nMatches, 1);
    derivator.printStatistics();
}
|
|
|
|
// Warn: F is the number of input channels, which is inversed compared with
|
|
// normal Conv.
|
|
// Our data layout: NHWF -> NHWC, RSFC
|
|
// Pytorch data layout: NFHW -> NCHW, FCRS
|
|
RangeOp buildTConv4x4_NHWF_RSFC(const int N, const int C, const int H,
|
|
const int W, const int F, const int R,
|
|
const int S) {
|
|
assert(R == 4 && S == 4);
|
|
const int OH = 2 * H, OW = 2 * W;
|
|
DEFINE_VAR(n);
|
|
DEFINE_VAR(c);
|
|
DEFINE_VAR(f);
|
|
DEFINE_VAR(r);
|
|
DEFINE_VAR(s);
|
|
DEFINE_VAR(x1);
|
|
DEFINE_VAR(x2);
|
|
DEFINE_VAR(y1);
|
|
DEFINE_VAR(y2);
|
|
DEFINE_VAR(i2);
|
|
DEFINE_VAR(i4);
|
|
DEFINE_VAR(h);
|
|
DEFINE_VAR(w);
|
|
// dilation * (kernel_size - 1) - padding
|
|
int padding = 1 * (R - 1) - 1;
|
|
auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
|
|
vector<int>{0, padding, padding, 0});
|
|
auto K = make_ref<TensorNode>("K", vector<int>({R, S, F, C}));
|
|
|
|
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
|
|
auto subK =
|
|
makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c});
|
|
// x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2
|
|
|
|
auto range1 = makeRangeOperator(
|
|
{
|
|
{n, {0, N}},
|
|
{c, {0, C}},
|
|
{x1, {0, OH / 2 + 1}},
|
|
{x2, {0, 2}},
|
|
{y1, {0, OW / 2 + 1}},
|
|
{y2, {0, 2}},
|
|
},
|
|
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
|
|
dbg(range1);
|
|
auto sub0 = makeSubscript(
|
|
range1, {n, c, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2});
|
|
auto range0 = makeRangeOperator(
|
|
{{n, {0, N}}, {h, {0, OH}}, {w, {0, OW}}, {c, {0, C}}}, {}, sub0);
|
|
return range0;
|
|
}
|
|
|
|
TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) {
    // 4x4 transposed convolution (NHWF input, RSFC kernel; F is the input
    // channel count), built by buildTConv4x4_NHWF_RSFC.
    const int N = 16, H = 2, W = 2, C = 256, F = 448;
    const int R = 4, S = 4;
    RangeOp range = buildTConv4x4_NHWF_RSFC(N, C, H, W, F, R, S);

    Formula tconv_4x4(range, 0);
    Derivator derivator(maxDepth, useHash, mode, passMode);
    if (!isRuleBased) {
        derivator.search(tconv_4x4, 0);
    } else {
        // Replay a fixed rule sequence; results are dumped under the given
        // prefix.
        const vector<int> rules{3, 2, 2, 2, 2, 5};
        derivator.setDumpFirstSuccess("TConv4x4_NHWF_RSFC_");
        derivator.ruleBasedDFS(tconv_4x4, 0, rules, {}, true);
    }

    EXPECT_GE(derivator.getNumCandidates(), 1);
    // At least one derived candidate must match the stored reference.
    const int nMatches = matchExprResult(
        derivator,
        "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr");
    EXPECT_GE(nMatches, 1);
    derivator.printStatistics();
}
|
|
|
|
TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) {
    // 5x5 convolution in NCHW / FCRS layout:
    //   out[n,f,h,w] = sum_{c,r,s} A[n, c, h+r-R/2, w+s-S/2] * K[f,c,r,s]
    int N = 16, C = 32, H = 224, W = 224, F = 1;
    int R = 5, S = 5;
    DEFINE_VAR(n);
    DEFINE_VAR(c);
    DEFINE_VAR(h);
    DEFINE_VAR(w);
    DEFINE_VAR(f);
    DEFINE_VAR(r);
    DEFINE_VAR(s);
    // A is padded by half the kernel size on H and W ("same" padding).
    auto A = make_ref<TensorNode>("A", vector<int>({N, C, H, W}),
                                  vector<int>{0, 0, R / 2, S / 2});
    auto K = make_ref<TensorNode>("K", vector<int>({F, C, R, S}));

    auto subA = makeSubscript(A, {n, c, h + r - R / 2, w + s - S / 2});
    auto subK = makeSubscript(K, {f, c, r, s});

    auto range =
        makeRangeOperator({{n, {0, N}}, {f, {0, F}}, {h, {0, H}}, {w, {0, W}}},
                          {{c, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA * subK);

    Formula conv_5x5(range, 0);
    Derivator derivator(maxDepth, useHash, mode, passMode);

    if (isRuleBased) {
        // Replay a fixed rule sequence; results are dumped under the given
        // prefix.
        const vector<int> rules{9, 1, 1, 3, 2, 2, 5, 8, 8, 6, 6};
        derivator.setDumpFirstSuccess("Conv2conv_5x5_NCHW_FCRS_");
        derivator.ruleBasedDFS(conv_5x5, 0, rules, {}, true);
    } else {
        derivator.search(conv_5x5, 0);
    }

    // Same sanity check as the other OpSearch tests: the derivation must
    // produce at least one candidate before matching against the reference.
    EXPECT_GE(derivator.getNumCandidates(), 1);
    int nMatches = matchExprResult(
        derivator,
        "../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr");
    derivator.printStatistics();
    EXPECT_GE(nMatches, 1);
}
|
|
|
|
TEST_F(OpSearch, G2BMM_RuleBased) {
    // Dilated banded batched matmul (G2BMM):
    //   out[b, m, w] = sum_k A[b, m, k] * B[b, m + dilation * (w - W), k]
    // for w in [0, 2W]: each row m of A meets 2W+1 rows of B spaced
    // `dilation` apart.
    const int n_heads = 8, seq_len = 10000, feat_len = 512;
    const int Batch = n_heads, M = seq_len, K = feat_len, W = 32, dilation = 4;
    // B is padded along M by the largest band offset, dilation * W.
    const int bandPad = dilation * W;

    DEFINE_VAR(b);
    DEFINE_VAR(m);
    DEFINE_VAR(w);
    DEFINE_VAR(k);

    auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
                                  vector<int>{0, 0, 0});
    auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
                                  vector<int>{0, bandPad, 0});
    auto subA = makeSubscript(A, {b, m, k});
    auto subB = makeSubscript(B, {b, m + dilation * (w - W), k});
    auto range =
        makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {0, 2 * W + 1}}},
                          {{k, {0, K}}}, subA * subB);

    // Derivation: this works without the padding check in stage merging.
    Formula dilated_g2bmm(range, 0);
    Derivator derivator(maxDepth, useHash, mode, passMode);
    if (!isRuleBased) {
        derivator.search(dilated_g2bmm, 0);
    } else {
        // Replay a fixed rule sequence; results are dumped under the given
        // prefix.
        const vector<int> rules{1, 7, 7, 2, 8, 6, 6};
        derivator.setDumpFirstSuccess("G2BMM_");
        derivator.ruleBasedDFS(dilated_g2bmm, 0, rules);
    }

    EXPECT_GE(derivator.getNumCandidates(), 1);
    // At least one derived candidate must match the stored reference.
    const int nMatches =
        matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr");
    EXPECT_GE(nMatches, 1);
    derivator.printStatistics();
}