forked from jiuyuan/InfiniTensor
add rule for dilated conv
This commit is contained in:
parent
d051460c23
commit
225a42f22d
|
@ -60,7 +60,7 @@ class NMutator : public Mutator {
|
|||
// Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
|
||||
Graph transformConvtransposed1x1(Operator _op);
|
||||
// Graph transformConvtransposed(Operator op);
|
||||
// Graph transformDialtedConv(Operator op);
|
||||
Graph transformDialtedConv(Operator _op);
|
||||
// Graph transformConv1x1(Operator op);
|
||||
// Graph transformConv1xk(Operator op);
|
||||
};
|
||||
|
|
|
@ -81,10 +81,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
|
|||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
// if (infini::Graph g = transformDialtedConv(computeOps[0])) {
|
||||
// out_graphs.emplace_back(g);
|
||||
// return;
|
||||
// }
|
||||
if (infini::Graph g = transformDialtedConv(computeOps[0])) {
|
||||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
}
|
||||
// // if (infini::Graph g = transformConv1x1(computeOps[0])) {
|
||||
// // Graph graph = new Graph(g->getOperators());
|
||||
// // out_graphs.emplace_back(graph);
|
||||
|
@ -476,45 +476,34 @@ double NMutator::memboundTime(const Shape &dims) {
|
|||
// return new Graph(g->getOperators());
|
||||
// }
|
||||
|
||||
// Graph NMutator::transformDialtedConv(Operator op) {
|
||||
// if (auto convOp = dynamic_cast<ConvOp *>(op)) {
|
||||
// if (convOp->getPh() == convOp->getDh() && convOp->getSh() == 1 &&
|
||||
// convOp->getDh() > 1 && convOp->getDh() == convOp->getDw()) {
|
||||
// const int d = convOp->getDh();
|
||||
// assert(convOp->getInputs()[0]->getDims()[2] % d == 0);
|
||||
// auto g = new infini::Graph();
|
||||
// auto inputDims = convOp->getInputs(0)->getDims();
|
||||
// auto weightDims = convOp->getInputs(1)->getDims();
|
||||
// auto outputDims = convOp->getOutput()->getDims();
|
||||
// auto newA = g->tensor({inputDims[0] * d * d, inputDims[1],
|
||||
// inputDims[2] / d, inputDims[3] / d});
|
||||
// // auto newW = g->tensor(
|
||||
// // {weightDims[0] * weightDims[1] * weightDims[3],
|
||||
// // weightDims[2]});
|
||||
// auto newO =
|
||||
// g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
|
||||
// weightDims[0] * weightDims[1] * weightDims[3]});
|
||||
// g->membound(
|
||||
// {convOp->getInputs(0)}, {newA}, {}, nullptr,
|
||||
// memboundTime(convOp->getInputs(0)->size() + newA->size()),
|
||||
// "DConv Tranpose Input");
|
||||
// // g->membound({convOp->getInputs(1)}, {newW}, {}, nullptr, 0,
|
||||
// // "Tranpose Weight");
|
||||
// g->conv(newA, convOp->getInputs(1), newO, 1, 1, 1, 1);
|
||||
// g->reshape(newO, convOp->getOutput());
|
||||
// dbg(newO->size(), convOp->getOutput()->size());
|
||||
// assert(newO->size() == convOp->getOutput()->size());
|
||||
// g->membound(
|
||||
// {newO}, {convOp->getOutput()}, {}, nullptr,
|
||||
// memboundTime(newO->size() + convOp->getOutput()->size()),
|
||||
// "DConv Tranpose Output");
|
||||
// g->updateConnection();
|
||||
// Graph graph = new Graph(g->getOperators());
|
||||
// return graph;
|
||||
// }
|
||||
// }
|
||||
// return nullptr;
|
||||
// }
|
||||
Graph NMutator::transformDialtedConv(Operator _op) {
|
||||
auto op = as<ConvObj>(_op);
|
||||
if (!op)
|
||||
return {};
|
||||
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
if (dh <= 1 && dw <= 1)
|
||||
return {};
|
||||
const auto &A = op->getInputs()[0];
|
||||
const auto &W = op->getInputs()[1];
|
||||
const auto &O = op->getOutput();
|
||||
const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const Shape inputDims = A->getDims();
|
||||
const Shape weightDims = W->getDims();
|
||||
const Shape outputDims = O->getDims();
|
||||
const DataType dtype = A->getDType();
|
||||
IT_ASSERT_TODO(dh == dw && ph == dh && pw == dw);
|
||||
IT_ASSERT_TODO(tie(sh, sw) == tuple(1, 1));
|
||||
IT_ASSERT_TODO(h % dh == 0 && w % dw == 0);
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
auto newA = g->addTensor({dh * dw * n, c, h / dh, h / dw}, dtype);
|
||||
// HACH: without transpose
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
|
||||
Tensor newO =
|
||||
g->addOp<ConvObj>(newA, W, nullptr, 1, 1, sh, sw, 1, 1)->getOutput();
|
||||
// HACH: without transpose
|
||||
g->addOpWithOutputs<ReshapeObj>(newO, g->cloneTensor(O), O->getDims());
|
||||
return g;
|
||||
}
|
||||
|
||||
Graph NMutator::transformConvtransposed1x1(Operator _op) {
|
||||
auto op = as<ConvTransposed2dNHWCObj>(_op);
|
||||
|
|
Loading…
Reference in New Issue