commit 225a42f22d (parent d051460c23)
Author: whjthu
Date:   2023-04-21 23:40:45 +08:00

    add rule for dilated conv

2 changed files with 33 additions and 44 deletions
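As far as the diff shows, the new rule rewrites a dilated convolution (equal dilation d > 1 on both axes, stride 1, padding equal to the dilation, h and w divisible by d) into a dense stride-1 convolution: the [n, c, h, w] input is regrouped into its d*d interleaved sub-grids of shape [d*d*n, c, h/d, w/d], convolved with the unchanged weights at padding 1, and reshaped back to the original output shape; runSingleOp now tries this rule alongside the existing ones. The interleaving transpose itself is still stubbed out with plain Reshape ops (the HACK comments below), so the rewrite is shape-correct but not yet value-correct; a standalone sketch of the full identity follows the diff.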


@@ -60,7 +60,7 @@ class NMutator : public Mutator {
     // Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
     Graph transformConvtransposed1x1(Operator _op);
     // Graph transformConvtransposed(Operator op);
-    // Graph transformDialtedConv(Operator op);
+    Graph transformDialtedConv(Operator _op);
     // Graph transformConv1x1(Operator op);
     // Graph transformConv1xk(Operator op);
 };


@@ -81,10 +81,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
         out_graphs.emplace_back(g);
         return;
     }
-    // if (infini::Graph g = transformDialtedConv(computeOps[0])) {
-    //     out_graphs.emplace_back(g);
-    //     return;
-    // }
+    if (infini::Graph g = transformDialtedConv(computeOps[0])) {
+        out_graphs.emplace_back(g);
+        return;
+    }
     // // if (infini::Graph g = transformConv1x1(computeOps[0])) {
     // //     Graph graph = new Graph(g->getOperators());
     // //     out_graphs.emplace_back(graph);
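An aside on why the preconditions asserted in the new transform (below) make the rewrite shape-consistent: with stride 1 and padding equal to the dilation d, the d*d sub-grid outputs tile the original output exactly. A quick standalone check of that arithmetic, with illustrative names:

// Shape check for the rewrite's preconditions (illustrative, standalone).
// For stride 1: out = in + 2*pad - dilation*(r - 1).
#include <cassert>
int main() {
    for (int d = 2; d <= 4; ++d)
        for (int r = 2; r <= 5; ++r)
            for (int h = d * r; h <= 64 * d; h += d) { // h % d == 0
                int ho = h + 2 * d - d * (r - 1);      // dilated, pad d
                int hs = h / d;                        // sub-grid height
                int hso = hs + 2 * 1 - (r - 1);        // dense, pad 1
                assert(hso * d == ho); // sub-grid outputs tile the output
            }
}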
@@ -476,45 +476,34 @@ double NMutator::memboundTime(const Shape &dims) {
 //     return new Graph(g->getOperators());
 // }
 
-// Graph NMutator::transformDialtedConv(Operator op) {
-//     if (auto convOp = dynamic_cast<ConvOp *>(op)) {
-//         if (convOp->getPh() == convOp->getDh() && convOp->getSh() == 1 &&
-//             convOp->getDh() > 1 && convOp->getDh() == convOp->getDw()) {
-//             const int d = convOp->getDh();
-//             assert(convOp->getInputs()[0]->getDims()[2] % d == 0);
-//             auto g = new infini::Graph();
-//             auto inputDims = convOp->getInputs(0)->getDims();
-//             auto weightDims = convOp->getInputs(1)->getDims();
-//             auto outputDims = convOp->getOutput()->getDims();
-//             auto newA = g->tensor({inputDims[0] * d * d, inputDims[1],
-//                                    inputDims[2] / d, inputDims[3] / d});
-//             // auto newW = g->tensor(
-//             //     {weightDims[0] * weightDims[1] * weightDims[3],
-//             //      weightDims[2]});
-//             auto newO =
-//                 g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
-//                            weightDims[0] * weightDims[1] * weightDims[3]});
-//             g->membound(
-//                 {convOp->getInputs(0)}, {newA}, {}, nullptr,
-//                 memboundTime(convOp->getInputs(0)->size() + newA->size()),
-//                 "DConv Tranpose Input");
-//             // g->membound({convOp->getInputs(1)}, {newW}, {}, nullptr, 0,
-//             //             "Tranpose Weight");
-//             g->conv(newA, convOp->getInputs(1), newO, 1, 1, 1, 1);
-//             g->reshape(newO, convOp->getOutput());
-//             dbg(newO->size(), convOp->getOutput()->size());
-//             assert(newO->size() == convOp->getOutput()->size());
-//             g->membound(
-//                 {newO}, {convOp->getOutput()}, {}, nullptr,
-//                 memboundTime(newO->size() + convOp->getOutput()->size()),
-//                 "DConv Tranpose Output");
-//             g->updateConnection();
-//             Graph graph = new Graph(g->getOperators());
-//             return graph;
-//         }
-//     }
-//     return nullptr;
-// }
+Graph NMutator::transformDialtedConv(Operator _op) {
+    auto op = as<ConvObj>(_op);
+    if (!op)
+        return {};
+    const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+    if (dh <= 1 && dw <= 1)
+        return {};
+    const auto &A = op->getInputs()[0];
+    const auto &W = op->getInputs()[1];
+    const auto &O = op->getOutput();
+    const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
+    const Shape inputDims = A->getDims();
+    const Shape weightDims = W->getDims();
+    const Shape outputDims = O->getDims();
+    const DataType dtype = A->getDType();
+    IT_ASSERT_TODO(dh == dw && ph == dh && pw == dw);
+    IT_ASSERT_TODO(tie(sh, sw) == tuple(1, 1));
+    IT_ASSERT_TODO(h % dh == 0 && w % dw == 0);
+    auto g = make_ref<GraphObj>(runtime);
+    auto newA = g->addTensor({dh * dw * n, c, h / dh, w / dw}, dtype);
+    // HACK: without the space-to-batch transpose
+    g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
+    Tensor newO =
+        g->addOp<ConvObj>(newA, W, nullptr, 1, 1, sh, sw, 1, 1)->getOutput();
+    // HACK: without the batch-to-space transpose
+    g->addOpWithOutputs<ReshapeObj>(newO, g->cloneTensor(O), O->getDims());
+    return g;
+}
 
 Graph NMutator::transformConvtransposed1x1(Operator _op) {
     auto op = as<ConvTransposed2dNHWCObj>(_op);
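For reference, a self-contained sketch (not part of this commit; no InfiniTensor APIs; names illustrative; single channel and filter for brevity) of the identity the rule relies on: a dilated conv with dilation d, stride 1, and padding d equals a dense stride-1 conv at padding 1 over the d*d interleaved sub-grids of the input, with the outputs interleaved back. The two interleaving loops are exactly the space-to-batch and batch-to-space transposes that the commit stubs out with Reshape ops.

// dilated_conv_identity.cc -- standalone check of the space-to-batch rewrite.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

using Tensor = std::vector<float>; // flat [batch][h][w], one channel/filter

// Naive 2D convolution: stride 1, square r*r kernel, dilation d, padding p.
Tensor conv2d(const Tensor &A, const Tensor &W, int n, int h, int w, int r,
              int d, int p, int &ho, int &wo) {
    ho = h + 2 * p - d * (r - 1);
    wo = w + 2 * p - d * (r - 1);
    Tensor O(static_cast<size_t>(n) * ho * wo, 0.0f);
    for (int b = 0; b < n; ++b)
        for (int y = 0; y < ho; ++y)
            for (int x = 0; x < wo; ++x) {
                float acc = 0.0f;
                for (int i = 0; i < r; ++i)
                    for (int j = 0; j < r; ++j) {
                        const int yy = y - p + i * d, xx = x - p + j * d;
                        if (yy < 0 || yy >= h || xx < 0 || xx >= w)
                            continue; // zero padding
                        acc += A[(b * h + yy) * w + xx] * W[i * r + j];
                    }
                O[(b * ho + y) * wo + x] = acc;
            }
    return O;
}

int main() {
    const int n = 1, h = 8, w = 8, r = 3, d = 2;
    Tensor A(n * h * w), W(r * r);
    for (size_t k = 0; k < A.size(); ++k)
        A[k] = 0.01f * k;
    for (size_t k = 0; k < W.size(); ++k)
        W[k] = 0.1f * (k + 1);

    // Reference: the dilated conv the rule matches (stride 1, padding d).
    int ho, wo;
    const Tensor ref = conv2d(A, W, n, h, w, r, d, /*p=*/d, ho, wo);

    // Space-to-batch: newA[(i*d+j)*n+b][y][x] = A[b][y*d+i][x*d+j].
    // This is the transpose the commit stubs out with a plain Reshape.
    const int hs = h / d, ws = w / d;
    Tensor newA(static_cast<size_t>(d) * d * n * hs * ws);
    for (int i = 0; i < d; ++i)
        for (int j = 0; j < d; ++j)
            for (int b = 0; b < n; ++b)
                for (int y = 0; y < hs; ++y)
                    for (int x = 0; x < ws; ++x)
                        newA[(((i * d + j) * n + b) * hs + y) * ws + x] =
                            A[(b * h + y * d + i) * w + x * d + j];

    // Dense stride-1 conv on the sub-grids; padding becomes d/d = 1.
    int hso, wso;
    const Tensor newO =
        conv2d(newA, W, d * d * n, hs, ws, r, /*d=*/1, /*p=*/1, hso, wso);
    assert(hso * d == ho && wso * d == wo);

    // Batch-to-space: ref[b][y*d+i][x*d+j] == newO[(i*d+j)*n+b][y][x].
    for (int i = 0; i < d; ++i)
        for (int j = 0; j < d; ++j)
            for (int b = 0; b < n; ++b)
                for (int y = 0; y < hso; ++y)
                    for (int x = 0; x < wso; ++x) {
                        const float got =
                            newO[(((i * d + j) * n + b) * hso + y) * wso + x];
                        const float want =
                            ref[(b * ho + y * d + i) * wo + x * d + j];
                        assert(std::fabs(got - want) < 1e-4f);
                    }
    std::puts("space-to-batch rewrite matches the dilated conv");
    return 0;
}

Replacing the Reshape stubs with kernels that implement these two loops (or with membound ops, as the deleted prototype did via its "DConv Tranpose" stages) would make the rewrite value-correct as well as shape-correct.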