InfiniTensor/test/nnet/test_TConv2gemm.cc

392 lines
15 KiB
C++
Raw Normal View History

2022-08-08 16:02:07 +08:00
#include "nnet/Visitor/CountRoutineVisitor.h"
#include "nnet/Visitor/GetTensorsVisitor.h"
#include "nnet/Visitor/Interpreter.h"
#include "nnet/derivator.h"
#include "nnet/expr.h"
#include "nnet/test.h"
#include "gtest/gtest.h"
using namespace nnet;
using namespace std;
TEST(Conv2conv, TConv4x4_NHWC_innerStage_RuleBased) {
const int N = 1, H = 2, W = 2, C = 256, F = 448;
const int R = 4, S = 4;
const int OH = 2 * H, OW = 2 * W;
DEFINE_VAR(n, c, f, r, s, x1, x2, y1, y2);
DEFINE_VAR(i2, i4);
// dilation * (kernel_size - 1) - padding
int padding = 1 * (R - 1) - 1;
auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
vector<int>{0, padding, padding, 0});
auto K = make_ref<TensorNode>("K", vector<int>({F, R, S, C}));
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK = makeSubscript(
K, {f, -2 * r + (-1) * x2 + (R - 1), -2 * s + (-1) * y2 + (S - 1), c});
auto range = makeRangeOperator(
{
{n, {0, N}},
{c, {0, C}},
{x1, {0, OH / 2 + 1}},
{x2, {0, 2}},
{y1, {0, OW / 2 + 1}},
{y2, {0, 2}},
},
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
dbg(range);
const vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 90};
Formula conv_9x9(range, 0);
Derivator derivator;
derivator.ruleBasedDFS(
conv_9x9, 0, rules,
{{1, {x1, r}}, {2, {y1, s}}, {3, {x2, i2}}, {4, {y2, i4}}});
EXPECT_EQ(derivator.getSearchedMaxDepth(), ((int)rules.size()));
ASSERT_GE(derivator.getNumCandidates(), 1);
const auto &formula = derivator.getCandidates().front();
EXPECT_TRUE(CountRoutineVisitor().match(formula.root, 1, 0, 3));
derivator.print();
}
TEST(Conv2conv, TConv4x4_NHWC_RuleBased) {
const int N = 1, H = 2, W = 2, C = 256, F = 448;
const int R = 4, S = 4;
const int OH = 2 * H, OW = 2 * W;
DEFINE_VAR(n, c, h, w, f, r, s);
DEFINE_VAR(x1, x2, y1, y2);
// dilation * (kernel_size - 1) - padding
int padding = 1 * (R - 1) - 1;
auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
vector<int>{0, padding, padding, 0});
auto K = make_ref<TensorNode>("K", vector<int>({F, R, S, C}));
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK = makeSubscript(
K, {f, -2 * r + (-1) * x2 + (R - 1), -2 * s + (-1) * y2 + (S - 1), c});
// auto range =
// makeRangeOperator({{n, {0, N}}, {c, {0, H}}, {w, {0, W}}, {f, {0,
// F}}},
// {{f, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA *
// subK);
auto range = makeRangeOperator(
{
{n, {0, N}},
{x1, {0, OH / 2 + 1}},
{x2, {0, 2}},
{y1, {0, OW / 2 + 1}},
{y2, {0, 2}},
{c, {0, C}},
},
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
auto subOuter = makeSubscript(
range, {n, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2, c});
auto outerRange = makeRangeOperator(
{
{n, {0, N}},
{h, {0, OH}},
{w, {0, OW}},
{c, {0, C}},
},
{}, subOuter);
dbg(outerRange);
// Derivation: this work without padding check in stage merging
// const vector<int> rules{1, 1, 3, 2, 2, 5, 2, 2, 6, 4, 4, 4, 4, 6};
// Before Guided DLT seperated from rule2VarMerging
// const vector<int> rules{1, 1, 3, 2, 2, 5, 2, 2, 6, 6};
const vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90};
Formula conv_9x9(outerRange, 0);
Derivator derivator;
// derivator.ruleBasedDFS(conv_9x9, 0, rules,
// {{1, {"x1", "r"}},
// {2, {"y1", "s"}},
// {3, {"x2", "i2"}},
// {4, {"y2", "i4"}}});
derivator.ruleBasedDFS(conv_9x9, 0, rules);
EXPECT_EQ(derivator.getSearchedMaxDepth(), ((int)rules.size()));
ASSERT_GE(derivator.getNumCandidates(), 1);
const auto &formula = derivator.getCandidates().front();
EXPECT_TRUE(CountRoutineVisitor().match(formula.root, 1, 0, 3));
derivator.print();
}
TEST(Conv2conv, TConv4x4_BS16_NHWC_RuleBased) {
const int N = 16, H = 2, W = 2, C = 256, F = 448;
const int R = 4, S = 4;
const int OH = 2 * H, OW = 2 * W;
DEFINE_VAR(n, c, h, w, f, r, s);
DEFINE_VAR(x1, x2, y1, y2);
// dilation * (kernel_size - 1) - padding
int padding = 1 * (R - 1) - 1;
auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
vector<int>{0, padding, padding, 0});
auto K = make_ref<TensorNode>("K", vector<int>({F, R, S, C}));
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK = makeSubscript(
K, {f, -2 * r + (-1) * x2 + (R - 1), -2 * s + (-1) * y2 + (S - 1), c});
// auto range =
// makeRangeOperator({{n, {0, N}}, {c, {0, H}}, {w, {0, W}}, {f, {0,
// F}}},
// {{f, {0, C}}, {r, {0, R}}, {s, {0, S}}}, subA *
// subK);
auto range = makeRangeOperator(
{
{n, {0, N}},
{x1, {0, OH / 2 + 1}},
{x2, {0, 2}},
{y1, {0, OW / 2 + 1}},
{y2, {0, 2}},
{c, {0, C}},
},
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
auto subOuter = makeSubscript(
range, {n, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2, c});
auto outerRange = makeRangeOperator(
{
{n, {0, N}},
{h, {0, OH}},
{w, {0, OW}},
{c, {0, C}},
},
{}, subOuter);
dbg(outerRange);
// Derivation: this work without padding check in stage merging
// const vector<int> rules{1, 1, 3, 2, 2, 5, 2, 2, 6, 4, 4, 4, 4, 6};
// Before Guided DLT seperated from rule2VarMerging
// const vector<int> rules{1, 1, 3, 2, 2, 5, 2, 2, 6, 6};
const vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90};
Formula conv_9x9(outerRange, 0);
Derivator derivator;
// derivator.ruleBasedDFS(conv_9x9, 0, rules,
// {{1, {"x1", "r"}},
// {2, {"y1", "s"}},
// {3, {"x2", "i2"}},
// {4, {"y2", "i4"}}});
derivator.ruleBasedDFS(conv_9x9, 0, rules);
EXPECT_EQ(derivator.getSearchedMaxDepth(), ((int)rules.size()));
ASSERT_GE(derivator.getNumCandidates(), 1);
const auto &formula = derivator.getCandidates().front();
EXPECT_TRUE(CountRoutineVisitor().match(formula.root, 1, 0, 3));
derivator.print();
}
// Warn: F is the number of input channels, which is inversed compared with
// normal Conv.
// Our data layout: NHWF -> NHWC, FRSC
// Pytorch data layout: NFHW -> NCHW, FCRS
RangeOp buildTConv4x4_NHWF_FRSC(const int N, const int C, const int H,
const int W, const int F, const int R,
const int S) {
assert(R == 4 && S == 4);
const int OH = 2 * H, OW = 2 * W;
DEFINE_VAR(n, c, h, w, f, r, s);
DEFINE_VAR(x1, x2, y1, y2, i2, i4);
// dilation * (kernel_size - 1) - padding
int padding = 1 * (R - 1) - 1;
auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
vector<int>{0, padding, padding, 0});
auto K = make_ref<TensorNode>("K", vector<int>({F, R, S, C}));
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK =
makeSubscript(K, {f, (R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, c});
// x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2
auto range1 = makeRangeOperator(
{
{n, {0, N}},
{c, {0, C}},
{x1, {0, OH / 2 + 1}},
{x2, {0, 2}},
{y1, {0, OW / 2 + 1}},
{y2, {0, 2}},
},
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
dbg(range1);
auto sub0 = makeSubscript(
range1, {n, c, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2});
auto range0 = makeRangeOperator(
{{n, {0, N}}, {h, {0, OH}}, {w, {0, OW}}, {c, {0, C}}}, {}, sub0);
return range0;
}
// Correct input expression
// Warn: F is the number of input channels, which is inversed compared with
// normal Conv.
// Our data layout: NHWF -> NHWC, RSFC
// Pytorch data layout: NFHW -> NCHW, FCRS
RangeOp buildTConv4x4_NHWF_RSFC(const int N, const int C, const int H,
const int W, const int F, const int R,
const int S) {
assert(R == 4 && S == 4);
const int OH = 2 * H, OW = 2 * W;
DEFINE_VAR(n, c, h, w, f, r, s);
DEFINE_VAR(x1, x2, y1, y2, i2, i4);
// dilation * (kernel_size - 1) - padding
int padding = 1 * (R - 1) - 1;
auto A = make_ref<TensorNode>("A", vector<int>({N, H, W, F}),
vector<int>{0, padding, padding, 0});
auto K = make_ref<TensorNode>("K", vector<int>({R, S, F, C}));
auto subA = makeSubscript(A, {n, x1 + r - 1, y1 + s - 1, f});
auto subK =
makeSubscript(K, {(R - 2) - 2 * r + x2, (S - 2) - 2 * s + y2, f, c});
// x1=(h+1)//2, x2=(h+1)%2, y1=(w+1)//2
auto range1 = makeRangeOperator(
{
{n, {0, N}},
{c, {0, C}},
{x1, {0, OH / 2 + 1}},
{x2, {0, 2}},
{y1, {0, OW / 2 + 1}},
{y2, {0, 2}},
},
{{f, {0, F}}, {r, {0, R / 2}}, {s, {0, S / 2}}}, subA * subK);
dbg(range1);
auto sub0 = makeSubscript(
range1, {n, c, (h + 1) / 2, (h + 1) % 2, (w + 1) / 2, (w + 1) % 2});
auto range0 = makeRangeOperator(
{{n, {0, N}}, {h, {0, OH}}, {w, {0, OW}}, {c, {0, C}}}, {}, sub0);
return range0;
}
TEST(TConv2gemm, TConv4x4_NHWF_FRSC_correctness_of_input_expr) {
const int N = 1, H = 2, W = 2, C = 2, F = 3;
const int R = 4, S = 4;
RangeOp range0 = buildTConv4x4_NHWF_FRSC(N, C, H, W, F, R, S);
auto ans0 = Interpreter(range0).interpretAllOutput(range0);
// Pytorch results
// torch.conv_transpose2d(X, K, stride=2, padding=1, dilation=1)
// X, K is NFHW and FCRS
vector<int> ans1 = {190, 193, 740, 755, 770, 785, 592, 604,
992, 1016, 2704, 2770, 2836, 2902, 1832, 1874,
1184, 1208, 3232, 3298, 3364, 3430, 2168, 2210,
1114, 1135, 2660, 2711, 2762, 2813, 1624, 1654};
ASSERT_EQ(ans0.size(), ans1.size());
for (size_t i = 0; i < ans0.size(); ++i)
EXPECT_EQ(ans0[i], ans1[i]);
}
ssize_t getOffset(vector<ssize_t> index, vector<int> shape) {
ssize_t ret = index[0];
for (size_t i = 1; i < index.size(); ++i)
ret = ret * shape[i] + index[i];
return ret;
}
TEST(TConv2gemm, TConv4x4_NHWF_RSFC_correctness_of_input_expr) {
const int N = 1, H = 2, W = 2, C = 2, F = 3;
const int R = 4, S = 4;
RangeOp range0 = buildTConv4x4_NHWF_RSFC(N, C, H, W, F, R, S);
Interpreter::Inputs inputs;
for (const auto &[name, tensor] : GetTensorsVisitor().get(range0)) {
auto data = make_ref<vector<int>>(tensor->getSize());
if (name == "A") {
for (ssize_t i = 0; i < tensor->getSize(); i++)
data->operator[](i) = i;
} else if (name == "K") {
for (ssize_t r = 0; r < R; r++)
for (ssize_t s = 0; s < S; s++)
for (ssize_t f = 0; f < F; f++)
for (ssize_t c = 0; c < C; c++) {
ssize_t index =
getOffset({r, s, f, c}, {R, S, F, C});
ssize_t num = getOffset({f, r, s, c}, {F, R, S, C});
data->operator[](index) = num;
}
} else
assert(0);
inputs.emplace(name, data);
}
auto ans0 = Interpreter(inputs).interpretAllOutput(range0);
// Pytorch results
// torch.conv_transpose2d(X, K, stride=2, padding=1, dilation=1)
// X, K is NFHW and FCRS
vector<int> ans1 = {190, 193, 740, 755, 770, 785, 592, 604,
992, 1016, 2704, 2770, 2836, 2902, 1832, 1874,
1184, 1208, 3232, 3298, 3364, 3430, 2168, 2210,
1114, 1135, 2660, 2711, 2762, 2813, 1624, 1654};
ASSERT_EQ(ans0.size(), ans1.size());
for (size_t i = 0; i < ans0.size(); ++i)
EXPECT_EQ(ans0[i], ans1[i]);
}
// TODO: Test after passing RSFC
// TEST(TConv2gemm, TConv4x4_NHWF_FRSC_search) {
// const int N = 1, H = 2, W = 2, C = 256, F = 448;
// const int R = 4, S = 4;
// RangeOp range = buildTConv4x4_NHWF_FRSC(N, C, H, W, F, R, S);
// const vector<int> rules{3, 2, 2, 2, 2, 5, 8, 8, 6, 90};
// Formula conv_9x9(range, 0);
// Derivator derivator;
// // derivator.ruleBasedDFS(
// // conv_9x9, 0, rules,
// // {{1, {x1, r}}, {2, {y1, s}}, {3, {x2, i2}}, {4, {y2, i4}}});
// derivator.dfs(conv_9x9, 0);
// // EXPECT_EQ(derivator.getSearchedMaxDepth(), ((int)rules.size()));
// ASSERT_GE(derivator.getNumCandidates(), 1);
// const auto &formula = derivator.getCandidates().front();
// EXPECT_TRUE(CountRoutineVisitor().match(formula.root, 1, 0, 3));
// derivator.print();
// }
TEST(TConv2gemm, TConv4x4_NHWF_RSFC_search) {
const int N = 16, H = 2, W = 2, C = 256, F = 448;
const int R = 4, S = 4;
RangeOp range = buildTConv4x4_NHWF_RSFC(N, C, H, W, F, R, S);
Formula conv_9x9(range, 0);
Derivator derivator;
bool isRuleBased = false;
if (isRuleBased) {
const vector<int> rules{3, 2, 2, 2, 2, 5};
derivator.setDumpFirstSuccess("TConv4x4_NHWF_RSFC_");
derivator.ruleBasedDFS(conv_9x9, 0, rules, {}, true);
} else
derivator.search(conv_9x9, 0);
ASSERT_GE(derivator.getNumCandidates(), 1);
derivator.print();
// for (const auto &f : derivator.getCandidates()) {
// dbg(CountRoutineVisitor().count(f.root));
// }
int nMatches = matchExprResult(
Fix NNet tests after migration (#27) * Fix: interpreter ``` 4 - readlog (Failed) 8 - test_TConv2gemm (Failed) 11 - test_conv2conv (Failed) 12 - test_conv2gemm (Failed) 15 - test_g2bmm (Failed) 16 - test_guidedDLT (Subprocess aborted) 22 - test_mergeStage (Subprocess aborted) ``` * Exclude readlog from ctest * Fix: change the path of logs ``` 85% tests passed, 4 tests failed out of 27 Total Test time (real) = 100.69 sec The following tests FAILED: 10 - test_conv2conv (Timeout) 11 - test_conv2gemm (Timeout) 15 - test_guidedDLT (Subprocess aborted) 21 - test_mergeStage (Subprocess aborted) Errors while running CTest ``` - test_conv2conv 38529 ms total - test_conv2gemm 37098 ms total * Fix: test_mergeStage * Fix: test_guidedDLT ``` Start 1: test_graph 1/27 Test #1: test_graph ....................... Passed 0.05 sec Start 2: test_hash 2/27 Test #2: test_hash ........................ Passed 0.02 sec Start 3: test_conv 3/27 Test #3: test_conv ........................ Passed 4.98 sec Start 4: test_Interpreter 4/27 Test #4: test_Interpreter ................. Passed 6.30 sec Start 5: test_OpSearch 5/27 Test #5: test_OpSearch .................... Passed 0.02 sec Start 6: test_Rule2VariableMerging 6/27 Test #6: test_Rule2VariableMerging ........ Passed 0.03 sec Start 7: test_TConv2gemm 7/27 Test #7: test_TConv2gemm .................. Passed 29.45 sec Start 8: test_as_tvm 8/27 Test #8: test_as_tvm ...................... Passed 0.02 sec Start 9: test_compareFormulas 9/27 Test #9: test_compareFormulas ............. Passed 0.02 sec Start 10: test_conv2conv 10/27 Test #10: test_conv2conv ................... Passed 36.55 sec Start 11: test_conv2gemm 11/27 Test #11: test_conv2gemm ................... Passed 39.70 sec Start 12: test_dlt 12/27 Test #12: test_dlt ......................... Passed 0.03 sec Start 13: test_exprHash 13/27 Test #13: test_exprHash .................... Passed 0.02 sec Start 14: test_g2bmm 14/27 Test #14: test_g2bmm ....................... Passed 0.16 sec Start 15: test_guidedDLT 15/27 Test #15: test_guidedDLT ................... Passed 0.07 sec Start 16: test_matchConv 16/27 Test #16: test_matchConv ................... Passed 0.02 sec Start 17: test_matchElementWise 17/27 Test #17: test_matchElementWise ............ Passed 0.03 sec Start 18: test_matchMatmul 18/27 Test #18: test_matchMatmul ................. Passed 0.02 sec Start 19: test_matchReshape 19/27 Test #19: test_matchReshape ................ Passed 0.02 sec Start 20: test_memboundOp 20/27 Test #20: test_memboundOp .................. Passed 0.02 sec Start 21: test_mergeStage 21/27 Test #21: test_mergeStage .................. Passed 0.02 sec Start 22: test_oobChecker 22/27 Test #22: test_oobChecker .................. Passed 0.02 sec Start 23: test_rangeMagnify 23/27 Test #23: test_rangeMagnify ................ Passed 0.02 sec Start 24: test_relaxation 24/27 Test #24: test_relaxation .................. Passed 0.02 sec Start 25: test_serializer 25/27 Test #25: test_serializer .................. Passed 0.03 sec Start 26: test_simplify 26/27 Test #26: test_simplify .................... Passed 0.02 sec Start 27: test_subset 27/27 Test #27: test_subset ...................... Passed 0.01 sec 100% tests passed, 0 tests failed out of 27 Total Test time (real) = 117.72 sec ``` * Fix: format * Replace nnet:Ref with infini::Ref ``` Start 1: test_graph 1/27 Test 1: test_graph ....................... Passed 0.02 sec Start 2: test_hash 2/27 Test 2: test_hash ........................ Passed 0.02 sec Start 3: test_conv 3/27 Test 3: test_conv ........................ Passed 4.45 sec Start 4: test_Interpreter 4/27 Test 4: test_Interpreter ................. Passed 4.37 sec Start 5: test_OpSearch 5/27 Test 5: test_OpSearch .................... Passed 0.02 sec Start 6: test_Rule2VariableMerging 6/27 Test 6: test_Rule2VariableMerging ........ Passed 0.02 sec Start 7: test_TConv2gemm 7/27 Test 7: test_TConv2gemm .................. Passed 23.40 sec Start 8: test_as_tvm 8/27 Test 8: test_as_tvm ...................... Passed 0.02 sec Start 9: test_compareFormulas 9/27 Test 9: test_compareFormulas ............. Passed 0.01 sec Start 10: test_conv2conv 10/27 Test 10: test_conv2conv ................... Passed 32.28 sec Start 11: test_conv2gemm 11/27 Test 11: test_conv2gemm ................... Passed 29.41 sec Start 12: test_dlt 12/27 Test 12: test_dlt ......................... Passed 0.02 sec Start 13: test_exprHash 13/27 Test 13: test_exprHash .................... Passed 0.01 sec Start 14: test_g2bmm 14/27 Test 14: test_g2bmm ....................... Passed 0.14 sec Start 15: test_guidedDLT 15/27 Test 15: test_guidedDLT ................... Passed 0.06 sec Start 16: test_matchConv 16/27 Test 16: test_matchConv ................... Passed 0.02 sec Start 17: test_matchElementWise 17/27 Test 17: test_matchElementWise ............ Passed 0.02 sec Start 18: test_matchMatmul 18/27 Test 18: test_matchMatmul ................. Passed 0.02 sec Start 19: test_matchReshape 19/27 Test 19: test_matchReshape ................ Passed 0.01 sec Start 20: test_memboundOp 20/27 Test 20: test_memboundOp .................. Passed 0.02 sec Start 21: test_mergeStage 21/27 Test 21: test_mergeStage .................. Passed 0.01 sec Start 22: test_oobChecker 22/27 Test 22: test_oobChecker .................. Passed 0.01 sec Start 23: test_rangeMagnify 23/27 Test 23: test_rangeMagnify ................ Passed 0.01 sec Start 24: test_relaxation 24/27 Test 24: test_relaxation .................. Passed 0.01 sec Start 25: test_serializer 25/27 Test 25: test_serializer .................. Passed 0.02 sec Start 26: test_simplify 26/27 Test 26: test_simplify .................... Passed 0.01 sec Start 27: test_subset 27/27 Test 27: test_subset ...................... Passed 0.00 sec 100% tests passed, 0 tests failed out of 27 Total Test time (real) = 94.47 sec ``` * Relax time limit for CPU conv ``` Start 1: test_graph 1/29 Test 1: test_graph ....................... Passed 0.02 sec Start 2: test_hash 2/29 Test 2: test_hash ........................ Passed 0.02 sec Start 3: test_conv 3/29 Test 3: test_conv ........................ Passed 4.47 sec Start 4: test_matmul 4/29 Test 4: test_matmul ...................... Passed 2.61 sec Start 5: test_pooling 5/29 Test 5: test_pooling ..................... Passed 2.57 sec Start 6: test_Interpreter 6/29 Test 6: test_Interpreter ................. Passed 4.35 sec Start 7: test_OpSearch 7/29 Test 7: test_OpSearch .................... Passed 0.02 sec Start 8: test_Rule2VariableMerging 8/29 Test 8: test_Rule2VariableMerging ........ Passed 0.02 sec Start 9: test_TConv2gemm 9/29 Test 9: test_TConv2gemm .................. Passed 23.32 sec Start 10: test_as_tvm 10/29 Test 10: test_as_tvm ...................... Passed 0.02 sec Start 11: test_compareFormulas 11/29 Test 11: test_compareFormulas ............. Passed 0.02 sec Start 12: test_conv2conv 12/29 Test 12: test_conv2conv ................... Passed 32.12 sec Start 13: test_conv2gemm 13/29 Test 13: test_conv2gemm ................... Passed 30.59 sec Start 14: test_dlt 14/29 Test 14: test_dlt ......................... Passed 0.02 sec Start 15: test_exprHash 15/29 Test 15: test_exprHash .................... Passed 0.01 sec Start 16: test_g2bmm 16/29 Test 16: test_g2bmm ....................... Passed 0.14 sec Start 17: test_guidedDLT 17/29 Test 17: test_guidedDLT ................... Passed 0.07 sec Start 18: test_matchConv 18/29 Test 18: test_matchConv ................... Passed 0.02 sec Start 19: test_matchElementWise 19/29 Test 19: test_matchElementWise ............ Passed 0.02 sec Start 20: test_matchMatmul 20/29 Test 20: test_matchMatmul ................. Passed 0.02 sec Start 21: test_matchReshape 21/29 Test 21: test_matchReshape ................ Passed 0.02 sec Start 22: test_memboundOp 22/29 Test 22: test_memboundOp .................. Passed 0.02 sec Start 23: test_mergeStage 23/29 Test 23: test_mergeStage .................. Passed 0.01 sec Start 24: test_oobChecker 24/29 Test 24: test_oobChecker .................. Passed 0.02 sec Start 25: test_rangeMagnify 25/29 Test 25: test_rangeMagnify ................ Passed 0.02 sec Start 26: test_relaxation 26/29 Test 26: test_relaxation .................. Passed 0.02 sec Start 27: test_serializer 27/29 Test 27: test_serializer .................. Passed 0.03 sec Start 28: test_simplify 28/29 Test 28: test_simplify .................... Passed 0.02 sec Start 29: test_subset 29/29 Test 29: test_subset ...................... Passed 0.00 sec 100% tests passed, 0 tests failed out of 29 Total Test time (real) = 100.65 sec ``` * Remove out-of-date tests Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
2022-09-13 15:17:22 +08:00
derivator,
"test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr");
2022-08-08 16:02:07 +08:00
EXPECT_GE(nMatches, 1);
derivator.printStatistics();
}
TEST(TConv2gemm, TConv4x4_NHWF_FRSC_CheckDerivationCorrectness_log) {
const string fnPrefix =
"test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_";
2022-08-08 16:02:07 +08:00
EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 11));
}
// TODO: correct ConvTransPattern
TEST(Conv2conv, InfoGAN_ConvTranspose_3_OOB_Test) {
// ConvTranspose_3 in InfoGAN
const int n = 1, c = 256, h = 2, w = 2, f = 448, r = 4, s = 4;
int padding = 1 * (r - 1) - 1;
const auto A = nnet::makeTensor("A", {n, h, w, f},
std::vector<int>{0, padding, padding, 0});
const auto K = nnet::makeTensor("K", {f, c, r, s});
auto expr = ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s);
dbg(expr);
Derivator derivator;
derivator.checkOOB(as<RangeOpNode>(expr));
}