Add: show conv2gemm derivation

This commit is contained in:
Liyan Zheng 2023-11-10 22:49:07 +08:00
parent 29071ddcac
commit 295450e5f4
5 changed files with 82 additions and 4 deletions

3
.gitignore vendored
View File

@ -31,8 +31,7 @@
*.out *.out
*.app *.app
build/ *build*/
build_debug/
.vscode/ .vscode/

2
run.sh Normal file
View File

@ -0,0 +1,2 @@
./test_mutator --gtest_filter="NMutator.Conv3x3_to_gemm"
./test_conv2gemm --gtest_filter="*.NCHW_FCRS_ruleBased"

View File

@ -413,7 +413,8 @@ void Derivator::appendCanddiate(const Tensor &tensor, int depth) {
// dbg("!!!!!!!!!!!!!!!Success!!!!!!!!!!!!!!!"); // dbg("!!!!!!!!!!!!!!!Success!!!!!!!!!!!!!!!");
if (enableEquivalenceCheck) if (enableEquivalenceCheck)
checkDerivationEquivalence(); checkDerivationEquivalence();
// printIntermediateStates(); if (candidates.size() > 2)
printIntermediateStates();
} }
bool Derivator::checkOOB(const RangeOp &rangeOp, bool halt) { bool Derivator::checkOOB(const RangeOp &rangeOp, bool halt) {

View File

@ -24,7 +24,7 @@ TEST(Conv2gemm, NCHW_FCRS_ruleBased) {
// Derivation // Derivation
Formula conv_1x1_nhwc_fc(range, 0); Formula conv_1x1_nhwc_fc(range, 0);
Derivator derivator(12); Derivator derivator(12, true, nnet::Derivator::LogMode::DumpFristCandiate);
// const vector<int> rules = {3, 2, 2, 5, 2, 2, 6, 6}; // const vector<int> rules = {3, 2, 2, 5, 2, 2, 6, 6};
const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90}; const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
// derivator.ruleBasedDFS(conv_1x1_nhwc_fc, 0, rules); // derivator.ruleBasedDFS(conv_1x1_nhwc_fc, 0, rules);

View File

@ -132,6 +132,82 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
} }
} }
TEST(NMutator, Conv3x3_to_gemm) {
const bool useMutatorDirectly = true;
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph original_graph = make_ref<GraphObj>(runtime);
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
Graph gCpu = make_ref<GraphObj>(cpu);
const int n = 1, c = 32, h = 64, w = 64, f = 32, r = 3, s = 3;
// // Minimum config for test
// const int n = 1, c = 1, h = 2, w = 2, f = 1, r = 4, s = 4;
// const int n = 1, c = 2, h = 2, w = 2, f = 2, r = 4, s = 4;
auto i0 = original_graph->addTensor({n, c, h, w}, DataType::Float32,
TensorType::Input);
auto w0 = original_graph->addTensor({f, c, r, s}, DataType::Float32,
TensorType::Initialized);
original_graph->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 1, 1, 1, 1);
auto mutator = make_ref<NMutator>(NMutator::Mode::RuleBased,
vector<int>{3, 2, 2, 5, 8, 8, 6, 90});
// // Translate OP to membound without derivation
// mutator->setToNaiveMembound();
vector<Graph> bestGraphs;
if (useMutatorDirectly) { // Use mutator results
bestGraphs = mutator->run(original_graph);
} else { // Use search engine results
SearchEngine searchEngine(runtime, mutator);
bestGraphs.emplace_back(searchEngine.run(original_graph));
}
original_graph->dataMalloc();
map<UidBaseType, Tensor> fuidToInputTensor;
for (auto t : original_graph->getInputs()) {
EXPECT_EQ(fuidToInputTensor.count(t->getFuid()), 0);
fuidToInputTensor[t->getFuid()] = t;
}
dbg(original_graph);
std::cout << "# bestGraphs = " << bestGraphs.size() << std::endl;
for (size_t i = 1; i < bestGraphs.size(); i++) {
auto bestGraphCpu = bestGraphs[i];
auto bestGraph =
make_ref<GraphObj>(runtime, bestGraphCpu->getOperators());
dbg(bestGraph);
break;
// auto gen = RandomGenerator(0.1, 0.1, i);
// bestGraph->dataMalloc();
// // Initialize inputs with random data
// for (auto t : g->getInputs()) {
// t->setData(gen);
// }
// for (auto t : bestGraph->getInputs()) {
// t->copyData(fuidToInputTensor[t->getFuid()]);
// }
// // Initialize outputs with zeros
// for (auto t : g->getOutputs()) {
// t->setData(ZeroGenerator());
// }
// for (auto t : bestGraph->getOutputs()) {
// t->setData(ZeroGenerator());
// }
// runtime->run(bestGraph, true); // Tune kernels
// runtime->run(g);
// runtime->run(bestGraph, false); // Execute transfomraed graph
// auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
// auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
// dbg(runtime->getPerfTime(bestGraph, true));
// EXPECT_TRUE(go0->equalData(bgo0, 1e-4));
}
}
// TEST(Mutator, Conv9x9) { // TEST(Mutator, Conv9x9) {
// auto g = new tpm::Graph(); // auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 1, 224, 224}); // auto i0 = g->tensor({1, 1, 224, 224});