diff --git a/include/nnet/nmutator.h b/include/nnet/nmutator.h index e3442fd3..c67cb96c 100644 --- a/include/nnet/nmutator.h +++ b/include/nnet/nmutator.h @@ -60,7 +60,7 @@ class NMutator : public Mutator { // Graph fuseHetConv(nnet::Expr expr, Graph in_graph); Graph transformConvtransposed1x1(Operator _op); // Graph transformConvtransposed(Operator op); - // Graph transformDialtedConv(Operator op); + Graph transformDialtedConv(Operator _op); // Graph transformConv1x1(Operator op); // Graph transformConv1xk(Operator op); }; diff --git a/src/nnet/nmutator.cc b/src/nnet/nmutator.cc index 4c76a7d8..67706859 100644 --- a/src/nnet/nmutator.cc +++ b/src/nnet/nmutator.cc @@ -81,10 +81,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector &out_graphs) { out_graphs.emplace_back(g); return; } - // if (infini::Graph g = transformDialtedConv(computeOps[0])) { - // out_graphs.emplace_back(g); - // return; - // } + if (infini::Graph g = transformDialtedConv(computeOps[0])) { + out_graphs.emplace_back(g); + return; + } // // if (infini::Graph g = transformConv1x1(computeOps[0])) { // // Graph graph = new Graph(g->getOperators()); // // out_graphs.emplace_back(graph); @@ -476,45 +476,34 @@ double NMutator::memboundTime(const Shape &dims) { // return new Graph(g->getOperators()); // } -// Graph NMutator::transformDialtedConv(Operator op) { -// if (auto convOp = dynamic_cast(op)) { -// if (convOp->getPh() == convOp->getDh() && convOp->getSh() == 1 && -// convOp->getDh() > 1 && convOp->getDh() == convOp->getDw()) { -// const int d = convOp->getDh(); -// assert(convOp->getInputs()[0]->getDims()[2] % d == 0); -// auto g = new infini::Graph(); -// auto inputDims = convOp->getInputs(0)->getDims(); -// auto weightDims = convOp->getInputs(1)->getDims(); -// auto outputDims = convOp->getOutput()->getDims(); -// auto newA = g->tensor({inputDims[0] * d * d, inputDims[1], -// inputDims[2] / d, inputDims[3] / d}); -// // auto newW = g->tensor( -// // {weightDims[0] * weightDims[1] * weightDims[3], -// // weightDims[2]}); -// auto newO = -// g->tensor({inputDims[0] * inputDims[1] * inputDims[2], -// weightDims[0] * weightDims[1] * weightDims[3]}); -// g->membound( -// {convOp->getInputs(0)}, {newA}, {}, nullptr, -// memboundTime(convOp->getInputs(0)->size() + newA->size()), -// "DConv Tranpose Input"); -// // g->membound({convOp->getInputs(1)}, {newW}, {}, nullptr, 0, -// // "Tranpose Weight"); -// g->conv(newA, convOp->getInputs(1), newO, 1, 1, 1, 1); -// g->reshape(newO, convOp->getOutput()); -// dbg(newO->size(), convOp->getOutput()->size()); -// assert(newO->size() == convOp->getOutput()->size()); -// g->membound( -// {newO}, {convOp->getOutput()}, {}, nullptr, -// memboundTime(newO->size() + convOp->getOutput()->size()), -// "DConv Tranpose Output"); -// g->updateConnection(); -// Graph graph = new Graph(g->getOperators()); -// return graph; -// } -// } -// return nullptr; -// } +Graph NMutator::transformDialtedConv(Operator _op) { + auto op = as(_op); + if (!op) + return {}; + const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + if (dh <= 1 && dw <= 1) + return {}; + const auto &A = op->getInputs()[0]; + const auto &W = op->getInputs()[1]; + const auto &O = op->getOutput(); + const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS(); + const Shape inputDims = A->getDims(); + const Shape weightDims = W->getDims(); + const Shape outputDims = O->getDims(); + const DataType dtype = A->getDType(); + IT_ASSERT_TODO(dh == dw && ph == dh && pw == dw); + IT_ASSERT_TODO(tie(sh, sw) == tuple(1, 1)); + IT_ASSERT_TODO(h % dh == 0 && w % dw == 0); + auto g = make_ref(runtime); + auto newA = g->addTensor({dh * dw * n, c, h / dh, h / dw}, dtype); + // HACH: without transpose + g->addOpWithOutputs(g->cloneTensor(A), newA, newA->getDims()); + Tensor newO = + g->addOp(newA, W, nullptr, 1, 1, sh, sw, 1, 1)->getOutput(); + // HACH: without transpose + g->addOpWithOutputs(newO, g->cloneTensor(O), O->getDims()); + return g; +} Graph NMutator::transformConvtransposed1x1(Operator _op) { auto op = as(_op);