InfiniTensor/test/nnet/test_g2bmm.cc

84 lines
3.0 KiB
C++
Raw Normal View History

2022-08-08 16:02:07 +08:00
#include "nnet/Visitor/CountRoutineVisitor.h"
#include "nnet/derivator.h"
#include "nnet/expr.h"
#include "nnet/test.h"
#include "gtest/gtest.h"
using namespace nnet;
using namespace std;
// Builds the GBMM expression
//     out[b, m, n] = sum_w A[b, m, w] * B[b, m + dilation * (w - W), n]
// and checks that the derivator produces at least one candidate, and that at
// least one result matches the reference expression stored in
// test/nnet/log/gbmm/GBMM_9.expr.
TEST(GBMM, RuleBased) {
    const int n_heads = 8, seq_len = 10000, feat_len = 512;
    const int Batch = n_heads, M = seq_len, K = feat_len, W = 32, dilation = 4;

    DEFINE_VAR(b);
    DEFINE_VAR(m);
    DEFINE_VAR(n);
    DEFINE_VAR(w);

    // NOTE(review): the second vector is presumably the per-dimension padding
    // of the tensor; B would then be padded by dilation * W along its second
    // dimension to cover the shifted index used below -- confirm against
    // TensorNode.
    auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, 2 * W + 1}),
                                  vector<int>{0, 0, 0});
    auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
                                  vector<int>{0, dilation * W, 0});

    auto subA = makeSubscript(A, {b, m, w});
    // Expanded form of: m + dilation * (w - W)
    auto subB = makeSubscript(B, {b, m + dilation * w - dilation * W, n});

    // First list: iterators b, m, n with their ranges; second list: iterator
    // w, which only appears inside the multiplied subscripts.
    auto range = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {n, {0, K}}},
                                   {{w, {0, 2 * W + 1}}}, subA * subB);
    dbg(range);

    // Derivation: this works without the padding check in stage merging.
    Formula dilatedGbmm(range, 0);
    Derivator derivator;

    // Set to true to replay the fixed rule sequence below (with
    // setDumpFirstSuccess("GBMM_") enabled) instead of running the search.
    const bool isRuleBased = false;
    if (isRuleBased) {
        const vector<int> rules{1, 7, 7, 2, 8, 6, 6};
        derivator.setDumpFirstSuccess("GBMM_");
        derivator.ruleBasedDFS(dilatedGbmm, 0, rules);
    } else {
        derivator.search(dilatedGbmm, 0);
    }

    // Fatal assertion: nothing below is meaningful without a candidate.
    ASSERT_GE(derivator.getNumCandidates(), 1);
    int nMatches = matchExprResult(derivator, "test/nnet/log/gbmm/GBMM_9.expr");
    EXPECT_GE(nMatches, 1);

    derivator.print();
    derivator.printStatistics();
}
TEST(G2BMM, RuleBased) {
const int n_heads = 8, seq_len = 10000, feat_len = 512;
const int Batch = n_heads, M = seq_len, K = feat_len, W = 32, dilation = 4;
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
auto B = make_ref<TensorNode>("B", vector<int>({Batch, M, K}),
vector<int>{0, dilation * W, 0});
auto subA = makeSubscript(A, {b, m, k});
auto subB = makeSubscript(B, {b, m + dilation * (w - W), k});
auto range =
makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}, {w, {0, 2 * W + 1}}},
{{k, {0, K}}}, subA * subB);
// Derivation: this work without padding check in stage merging
Formula dialted_g2bmm(range, 0);
Derivator derivator;
bool isRuleBased = false;
if (isRuleBased) {
const vector<int> rules{1, 7, 7, 2, 8, 6, 6};
derivator.setDumpFirstSuccess("G2BMM_");
derivator.ruleBasedDFS(dialted_g2bmm, 0, rules);
} else {
derivator.search(dialted_g2bmm, 0);
}
ASSERT_GE(derivator.getNumCandidates(), 1);
Fix NNet tests after migration (#27) * Fix: interpreter ``` 4 - readlog (Failed) 8 - test_TConv2gemm (Failed) 11 - test_conv2conv (Failed) 12 - test_conv2gemm (Failed) 15 - test_g2bmm (Failed) 16 - test_guidedDLT (Subprocess aborted) 22 - test_mergeStage (Subprocess aborted) ``` * Exclude readlog from ctest * Fix: change the path of logs ``` 85% tests passed, 4 tests failed out of 27 Total Test time (real) = 100.69 sec The following tests FAILED: 10 - test_conv2conv (Timeout) 11 - test_conv2gemm (Timeout) 15 - test_guidedDLT (Subprocess aborted) 21 - test_mergeStage (Subprocess aborted) Errors while running CTest ``` - test_conv2conv 38529 ms total - test_conv2gemm 37098 ms total * Fix: test_mergeStage * Fix: test_guidedDLT ``` Start 1: test_graph 1/27 Test #1: test_graph ....................... Passed 0.05 sec Start 2: test_hash 2/27 Test #2: test_hash ........................ Passed 0.02 sec Start 3: test_conv 3/27 Test #3: test_conv ........................ Passed 4.98 sec Start 4: test_Interpreter 4/27 Test #4: test_Interpreter ................. Passed 6.30 sec Start 5: test_OpSearch 5/27 Test #5: test_OpSearch .................... Passed 0.02 sec Start 6: test_Rule2VariableMerging 6/27 Test #6: test_Rule2VariableMerging ........ Passed 0.03 sec Start 7: test_TConv2gemm 7/27 Test #7: test_TConv2gemm .................. Passed 29.45 sec Start 8: test_as_tvm 8/27 Test #8: test_as_tvm ...................... Passed 0.02 sec Start 9: test_compareFormulas 9/27 Test #9: test_compareFormulas ............. Passed 0.02 sec Start 10: test_conv2conv 10/27 Test #10: test_conv2conv ................... Passed 36.55 sec Start 11: test_conv2gemm 11/27 Test #11: test_conv2gemm ................... Passed 39.70 sec Start 12: test_dlt 12/27 Test #12: test_dlt ......................... Passed 0.03 sec Start 13: test_exprHash 13/27 Test #13: test_exprHash .................... Passed 0.02 sec Start 14: test_g2bmm 14/27 Test #14: test_g2bmm ....................... 
Passed 0.16 sec Start 15: test_guidedDLT 15/27 Test #15: test_guidedDLT ................... Passed 0.07 sec Start 16: test_matchConv 16/27 Test #16: test_matchConv ................... Passed 0.02 sec Start 17: test_matchElementWise 17/27 Test #17: test_matchElementWise ............ Passed 0.03 sec Start 18: test_matchMatmul 18/27 Test #18: test_matchMatmul ................. Passed 0.02 sec Start 19: test_matchReshape 19/27 Test #19: test_matchReshape ................ Passed 0.02 sec Start 20: test_memboundOp 20/27 Test #20: test_memboundOp .................. Passed 0.02 sec Start 21: test_mergeStage 21/27 Test #21: test_mergeStage .................. Passed 0.02 sec Start 22: test_oobChecker 22/27 Test #22: test_oobChecker .................. Passed 0.02 sec Start 23: test_rangeMagnify 23/27 Test #23: test_rangeMagnify ................ Passed 0.02 sec Start 24: test_relaxation 24/27 Test #24: test_relaxation .................. Passed 0.02 sec Start 25: test_serializer 25/27 Test #25: test_serializer .................. Passed 0.03 sec Start 26: test_simplify 26/27 Test #26: test_simplify .................... Passed 0.02 sec Start 27: test_subset 27/27 Test #27: test_subset ...................... Passed 0.01 sec 100% tests passed, 0 tests failed out of 27 Total Test time (real) = 117.72 sec ``` * Fix: format * Replace nnet:Ref with infini::Ref ``` Start 1: test_graph 1/27 Test 1: test_graph ....................... Passed 0.02 sec Start 2: test_hash 2/27 Test 2: test_hash ........................ Passed 0.02 sec Start 3: test_conv 3/27 Test 3: test_conv ........................ Passed 4.45 sec Start 4: test_Interpreter 4/27 Test 4: test_Interpreter ................. Passed 4.37 sec Start 5: test_OpSearch 5/27 Test 5: test_OpSearch .................... Passed 0.02 sec Start 6: test_Rule2VariableMerging 6/27 Test 6: test_Rule2VariableMerging ........ Passed 0.02 sec Start 7: test_TConv2gemm 7/27 Test 7: test_TConv2gemm .................. 
Passed 23.40 sec Start 8: test_as_tvm 8/27 Test 8: test_as_tvm ...................... Passed 0.02 sec Start 9: test_compareFormulas 9/27 Test 9: test_compareFormulas ............. Passed 0.01 sec Start 10: test_conv2conv 10/27 Test 10: test_conv2conv ................... Passed 32.28 sec Start 11: test_conv2gemm 11/27 Test 11: test_conv2gemm ................... Passed 29.41 sec Start 12: test_dlt 12/27 Test 12: test_dlt ......................... Passed 0.02 sec Start 13: test_exprHash 13/27 Test 13: test_exprHash .................... Passed 0.01 sec Start 14: test_g2bmm 14/27 Test 14: test_g2bmm ....................... Passed 0.14 sec Start 15: test_guidedDLT 15/27 Test 15: test_guidedDLT ................... Passed 0.06 sec Start 16: test_matchConv 16/27 Test 16: test_matchConv ................... Passed 0.02 sec Start 17: test_matchElementWise 17/27 Test 17: test_matchElementWise ............ Passed 0.02 sec Start 18: test_matchMatmul 18/27 Test 18: test_matchMatmul ................. Passed 0.02 sec Start 19: test_matchReshape 19/27 Test 19: test_matchReshape ................ Passed 0.01 sec Start 20: test_memboundOp 20/27 Test 20: test_memboundOp .................. Passed 0.02 sec Start 21: test_mergeStage 21/27 Test 21: test_mergeStage .................. Passed 0.01 sec Start 22: test_oobChecker 22/27 Test 22: test_oobChecker .................. Passed 0.01 sec Start 23: test_rangeMagnify 23/27 Test 23: test_rangeMagnify ................ Passed 0.01 sec Start 24: test_relaxation 24/27 Test 24: test_relaxation .................. Passed 0.01 sec Start 25: test_serializer 25/27 Test 25: test_serializer .................. Passed 0.02 sec Start 26: test_simplify 26/27 Test 26: test_simplify .................... Passed 0.01 sec Start 27: test_subset 27/27 Test 27: test_subset ...................... 
Passed 0.00 sec 100% tests passed, 0 tests failed out of 27 Total Test time (real) = 94.47 sec ``` * Relax time limit for CPU conv ``` Start 1: test_graph 1/29 Test 1: test_graph ....................... Passed 0.02 sec Start 2: test_hash 2/29 Test 2: test_hash ........................ Passed 0.02 sec Start 3: test_conv 3/29 Test 3: test_conv ........................ Passed 4.47 sec Start 4: test_matmul 4/29 Test 4: test_matmul ...................... Passed 2.61 sec Start 5: test_pooling 5/29 Test 5: test_pooling ..................... Passed 2.57 sec Start 6: test_Interpreter 6/29 Test 6: test_Interpreter ................. Passed 4.35 sec Start 7: test_OpSearch 7/29 Test 7: test_OpSearch .................... Passed 0.02 sec Start 8: test_Rule2VariableMerging 8/29 Test 8: test_Rule2VariableMerging ........ Passed 0.02 sec Start 9: test_TConv2gemm 9/29 Test 9: test_TConv2gemm .................. Passed 23.32 sec Start 10: test_as_tvm 10/29 Test 10: test_as_tvm ...................... Passed 0.02 sec Start 11: test_compareFormulas 11/29 Test 11: test_compareFormulas ............. Passed 0.02 sec Start 12: test_conv2conv 12/29 Test 12: test_conv2conv ................... Passed 32.12 sec Start 13: test_conv2gemm 13/29 Test 13: test_conv2gemm ................... Passed 30.59 sec Start 14: test_dlt 14/29 Test 14: test_dlt ......................... Passed 0.02 sec Start 15: test_exprHash 15/29 Test 15: test_exprHash .................... Passed 0.01 sec Start 16: test_g2bmm 16/29 Test 16: test_g2bmm ....................... Passed 0.14 sec Start 17: test_guidedDLT 17/29 Test 17: test_guidedDLT ................... Passed 0.07 sec Start 18: test_matchConv 18/29 Test 18: test_matchConv ................... Passed 0.02 sec Start 19: test_matchElementWise 19/29 Test 19: test_matchElementWise ............ Passed 0.02 sec Start 20: test_matchMatmul 20/29 Test 20: test_matchMatmul ................. 
Passed 0.02 sec Start 21: test_matchReshape 21/29 Test 21: test_matchReshape ................ Passed 0.02 sec Start 22: test_memboundOp 22/29 Test 22: test_memboundOp .................. Passed 0.02 sec Start 23: test_mergeStage 23/29 Test 23: test_mergeStage .................. Passed 0.01 sec Start 24: test_oobChecker 24/29 Test 24: test_oobChecker .................. Passed 0.02 sec Start 25: test_rangeMagnify 25/29 Test 25: test_rangeMagnify ................ Passed 0.02 sec Start 26: test_relaxation 26/29 Test 26: test_relaxation .................. Passed 0.02 sec Start 27: test_serializer 27/29 Test 27: test_serializer .................. Passed 0.03 sec Start 28: test_simplify 28/29 Test 28: test_simplify .................... Passed 0.02 sec Start 29: test_subset 29/29 Test 29: test_subset ...................... Passed 0.00 sec 100% tests passed, 0 tests failed out of 29 Total Test time (real) = 100.65 sec ``` * Remove out-of-date tests Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
2022-09-13 15:17:22 +08:00
int nMatches =
matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr");
2022-08-08 16:02:07 +08:00
EXPECT_GE(nMatches, 1);
derivator.print();
derivator.printStatistics();
}