Add: Conv1x1 rule

Liyan Zheng 2023-04-21 23:21:04 +08:00
parent d051460c23
commit 16a8c5dce5
3 changed files with 41 additions and 64 deletions

@@ -61,7 +61,7 @@ class NMutator : public Mutator {
Graph transformConvtransposed1x1(Operator _op);
// Graph transformConvtransposed(Operator op);
// Graph transformDialtedConv(Operator op);
// Graph transformConv1x1(Operator op);
Graph transformConv1x1(Operator op);
// Graph transformConv1xk(Operator op);
};

@@ -85,11 +85,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
// out_graphs.emplace_back(g);
// return;
// }
// // if (infini::Graph g = transformConv1x1(computeOps[0])) {
// // Graph graph = new Graph(g->getOperators());
// // out_graphs.emplace_back(graph);
// // return;
// // }
if (Graph g = transformConv1x1(computeOps[0])) {
out_graphs.emplace_back(g);
return;
}
// // if (infini::Graph g = transformConv1xk(computeOps[0])) {
// // Graph graph = new Graph(g->getOperators());
// // out_graphs.emplace_back(graph);
@@ -100,7 +99,6 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
if (opSet.count(computeOps[0]->getOpType()) == 0)
return;
auto expr = opToExpression(computeOps[0]);
if (!expr)
return;
@@ -373,7 +371,13 @@ infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
g->addOpWithOutputs<ConvObj>(A, K, output, ph, pw, sh, sw, dh, dw);
} else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
assert(op->getInputs().size() == 1);
// dbg(op, op->getExpr());
// TODO: For a single input channel conv, it can be transformed into
// vec X vec ---> matrix --reduce--> result
// This transformation only introduces membound Ops and can have a
// wrongly estimated execution time, so we skip it for now.
if (op->getInputs().size() != 1)
return nullptr;
nnet::MatchReshapeVisitor matchReshapeVisitor;
// If this routine only change the shape, translate it to a Reshape
if (matchReshapeVisitor(op->getExpr())) {
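
A note on the TODO above: one way to read "vec X vec ---> matrix --reduce--> result" is that, for a single input channel, the outer product of the input vector and the kernel vector materializes every x[p] * w[j] product term once, and a reduction then gathers the terms belonging to each output element. A minimal 1-D NumPy sketch of that reading (shapes and names are illustrative, not from the repository):

import numpy as np

x = np.random.rand(10).astype(np.float32)  # single-channel 1-D input
w = np.random.rand(3).astype(np.float32)   # kernel

# Direct valid convolution (correlation form): y[i] = sum_j x[i+j] * w[j]
y = np.array([np.dot(x[i:i + 3], w) for i in range(8)])

# vec x vec -> matrix -> reduce: the outer product is one matmul-like op;
# the gather-and-sum is the membound reduction the comment refers to.
M = np.outer(x, w)                                            # [10, 3]
y2 = np.array([sum(M[i + j, j] for j in range(3)) for i in range(8)])

assert np.allclose(y, y2)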
@@ -598,61 +602,34 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
// return graph;
// }
// Graph NMutator::transformTConv1x1(Operator op) {
// if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
// if (tconvOp->getPh() == 0 && tconvOp->getSh() == 1) {
// auto g = new infini::Graph();
// auto inputDims = tconvOp->getInputs(0)->getDims();
// auto weightDims = tconvOp->getInputs(1)->getDims();
// auto outputDims = tconvOp->getOutput()->getDims();
// auto newA = g->tensor(
// {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
// auto newW = g->tensor(
// {weightDims[0] * weightDims[1] * weightDims[3],
// weightDims[2]});
// auto newO =
// g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
// weightDims[0] * weightDims[1] * weightDims[3]});
// g->reshape(tconvOp->getInputs(0), newA);
// g->reshape(tconvOp->getInputs(1), newW);
// g->matmul(newA, newW, newO, 0, 1);
// g->reshape(newO, tconvOp->getOutput());
// g->updateConnection();
// Graph graph = new Graph(g->getOperators());
// return graph;
// }
// }
// return nullptr;
// }
// Graph NMutator::transformConv1x1(Operator op) {
// auto convOp = dynamic_cast<ConvOp *>(op);
// if (!convOp)
// return nullptr;
// if (convOp->getPh() == 0 && convOp->getSh() == 1 &&
// convOp->getInputs()[1]->getDims()[2] == 1 &&
// convOp->getInputs()[1]->getDims()[3] == 1) {
// Transpose is required for BS>1
// // if (convOp->getInputs()[0]->getDims()[0] == 1) {
// auto g = new infini::Graph();
// auto inputDims = convOp->getInputs(0)->getDims();
// auto weightDims = convOp->getInputs(1)->getDims();
// auto outputDims = convOp->getOutput()->getDims();
// auto newA = g->tensor(
// {inputDims[1], inputDims[0] * inputDims[2] * inputDims[3]});
// auto newW = g->tensor({weightDims[0], weightDims[1]});
// auto newO = g->tensor(
// {weightDims[0], inputDims[0] * inputDims[2] * inputDims[3]});
// g->reshape(convOp->getInputs(0), newA);
// g->reshape(convOp->getInputs(1), newW);
// g->matmul(newW, newA, newO, 0, 0);
// g->reshape(newO, convOp->getOutput());
// g->updateConnection();
// Graph graph = new Graph(g->getOperators());
// return graph;
// }
// return nullptr;
// }
Graph NMutator::transformConv1x1(Operator _op) {
auto op = as<ConvObj>(_op);
if (!op)
return nullptr;
Shape shapeA = op->getInputs(0)->getDims();
Shape shapeW = op->getInputs(1)->getDims();
// TODO: support batch size > 1
if (shapeA[0] != 1)
return nullptr;
if (op->getPh() == 0 && op->getSh() == 1 && shapeW[2] == 1 &&
shapeW[3] == 1) {
auto g = make_ref<GraphObj>(runtime);
auto A =
g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
vector{shapeA[1], shapeA[0] * shapeA[2] *
shapeA[3]}) // [C, N*H*W]
->getOutput();
auto B = g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(1)), nullptr,
vector{shapeW[0], shapeW[1]}) // [F, C]
->getOutput();
auto O =
g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, N*H*W]
g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
op->getOutput()->getDims());
return g;
}
return nullptr;
}
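
For reference, the rewrite implemented by transformConv1x1 above can be checked numerically: with zero padding, unit stride, and batch size 1, a 1x1 convolution is exactly the matrix product of the [F, C] weight with the input viewed as [C, N*H*W]. A minimal NumPy sketch under those assumptions (shapes are illustrative):

import numpy as np

N, C, F, H, W = 1, 16, 32, 14, 14                  # N must be 1, see TODO above
A = np.random.rand(N, C, H, W).astype(np.float32)  # input  [N, C, H, W]
K = np.random.rand(F, C, 1, 1).astype(np.float32)  # weight [F, C, 1, 1]

# Direct 1x1 convolution: each pixel gets an independent C -> F linear map.
conv = np.einsum("nchw,fc->nfhw", A, K[:, :, 0, 0])

# The reshape-matmul-reshape path used by the rule above.
A2 = A.reshape(C, N * H * W)       # [C, N*H*W]; only valid as-is when N == 1
W2 = K.reshape(F, C)               # [F, C]
O = (W2 @ A2).reshape(N, F, H, W)  # [F, N*H*W] -> [N, F, H, W]

assert np.allclose(conv, O, atol=1e-5)

For N > 1, reshaping [N, C, H, W] directly to [C, N*H*W] would interleave the batches incorrectly, which is why the rule bails out early and why the old commented-out code noted that a transpose is required for BS>1.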
// Graph NMutator::transformConv1xk(Operator op) {
// auto convOp = dynamic_cast<ConvOp *>(op);

@@ -122,6 +122,6 @@ if __name__ == "__main__":
[3, 2, 2, 5, 8, 8, 6, 90])
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)
save_onnx(g, f"optimized_{name}.onnx")
save_onnx(g, f"opt_{name}.onnx")
verify_graphs(runtime, original_g, g)
run_and_evaluate(runtime, g)