forked from jiuyuan/InfiniTensor
Add: Conv1x1 rule
This commit is contained in:
parent d051460c23
commit 16a8c5dce5
@@ -61,7 +61,7 @@ class NMutator : public Mutator {
     Graph transformConvtransposed1x1(Operator _op);
     // Graph transformConvtransposed(Operator op);
     // Graph transformDialtedConv(Operator op);
-    // Graph transformConv1x1(Operator op);
+    Graph transformConv1x1(Operator op);
     // Graph transformConv1xk(Operator op);
 };

@@ -85,11 +85,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
     //     out_graphs.emplace_back(g);
     //     return;
     // }
-    // // if (infini::Graph g = transformConv1x1(computeOps[0])) {
-    // //     Graph graph = new Graph(g->getOperators());
-    // //     out_graphs.emplace_back(graph);
-    // //     return;
-    // // }
+    if (Graph g = transformConv1x1(computeOps[0])) {
+        out_graphs.emplace_back(g);
+        return;
+    }
     // // if (infini::Graph g = transformConv1xk(computeOps[0])) {
     // //     Graph graph = new Graph(g->getOperators());
     // //     out_graphs.emplace_back(graph);
@@ -100,7 +99,6 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
     if (opSet.count(computeOps[0]->getOpType()) == 0)
         return;
-

     auto expr = opToExpression(computeOps[0]);
     if (!expr)
         return;
@@ -373,7 +371,13 @@ infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
         const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
         g->addOpWithOutputs<ConvObj>(A, K, output, ph, pw, sh, sw, dh, dw);
     } else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
-        assert(op->getInputs().size() == 1);
+        // dbg(op, op->getExpr());
+        // TODO: a single-input-channel conv can be transformed into
+        //   vec X vec ---> matrix --reduce--> result.
+        // That transformation only introduces membound Ops and can have a
+        // wrongly estimated execution time, so we skip it for now.
+        if (op->getInputs().size() != 1)
+            return nullptr;
         nnet::MatchReshapeVisitor matchReshapeVisitor;
         // If this routine only changes the shape, translate it to a Reshape
         if (matchReshapeVisitor(op->getExpr())) {
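One plausible reading of the "vec X vec ---> matrix --reduce--> result" comment above, as a minimal NumPy sketch (not part of this commit; 1-D single-channel case for brevity): the outer product of input and kernel vectors holds every pairwise product, and each output element is a reduction over one diagonal band of that matrix.

```python
import numpy as np

# 1-D analogue of a single-channel conv lowered through an outer product.
A = np.random.rand(8)   # input vector
K = np.random.rand(3)   # kernel vector
M = np.outer(A, K)      # matrix of all pairwise products A[i] * K[j]

# out[x] = sum_r A[x + r] * K[r]  ==  sum over one diagonal band of M
out = np.array([M[x:x + len(K), :].diagonal().sum()
                for x in range(len(A) - len(K) + 1)])

ref = np.convolve(A, K[::-1], mode="valid")  # cross-correlation reference
assert np.allclose(out, ref)
```

The reduction step is memory-bound, which matches the stated reason for returning nullptr here: a rewrite made only of membound Ops would get a wrongly estimated execution time.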
@@ -598,61 +602,34 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) {
 //     return graph;
 // }

-// Graph NMutator::transformTConv1x1(Operator op) {
-//     if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
-//         if (tconvOp->getPh() == 0 && tconvOp->getSh() == 1) {
-//             auto g = new infini::Graph();
-//             auto inputDims = tconvOp->getInputs(0)->getDims();
-//             auto weightDims = tconvOp->getInputs(1)->getDims();
-//             auto outputDims = tconvOp->getOutput()->getDims();
-//             auto newA = g->tensor(
-//                 {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
-//             auto newW = g->tensor(
-//                 {weightDims[0] * weightDims[1] * weightDims[3],
-//                  weightDims[2]});
-//             auto newO =
-//                 g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
-//                            weightDims[0] * weightDims[1] * weightDims[3]});
-//             g->reshape(tconvOp->getInputs(0), newA);
-//             g->reshape(tconvOp->getInputs(1), newW);
-//             g->matmul(newA, newW, newO, 0, 1);
-//             g->reshape(newO, tconvOp->getOutput());
-//             g->updateConnection();
-//             Graph graph = new Graph(g->getOperators());
-//             return graph;
-//         }
-//     }
-//     return nullptr;
-// }
-
-// Graph NMutator::transformConv1x1(Operator op) {
-//     auto convOp = dynamic_cast<ConvOp *>(op);
-//     if (!convOp)
-//         return nullptr;
-//     if (convOp->getPh() == 0 && convOp->getSh() == 1 &&
-//         convOp->getInputs()[1]->getDims()[2] == 1 &&
-//         convOp->getInputs()[1]->getDims()[3] == 1) {
-//         // Transpose is required for BS > 1
-//         // if (convOp->getInputs()[0]->getDims()[0] == 1) {
-//         auto g = new infini::Graph();
-//         auto inputDims = convOp->getInputs(0)->getDims();
-//         auto weightDims = convOp->getInputs(1)->getDims();
-//         auto outputDims = convOp->getOutput()->getDims();
-//         auto newA = g->tensor(
-//             {inputDims[1], inputDims[0] * inputDims[2] * inputDims[3]});
-//         auto newW = g->tensor({weightDims[0], weightDims[1]});
-//         auto newO = g->tensor(
-//             {weightDims[0], inputDims[0] * inputDims[2] * inputDims[3]});
-//         g->reshape(convOp->getInputs(0), newA);
-//         g->reshape(convOp->getInputs(1), newW);
-//         g->matmul(newW, newA, newO, 0, 0);
-//         g->reshape(newO, convOp->getOutput());
-//         g->updateConnection();
-//         Graph graph = new Graph(g->getOperators());
-//         return graph;
-//     }
-//     return nullptr;
-// }
+Graph NMutator::transformConv1x1(Operator _op) {
+    auto op = as<ConvObj>(_op);
+    if (!op)
+        return nullptr;
+    Shape shapeA = op->getInputs(0)->getDims();
+    Shape shapeW = op->getInputs(1)->getDims();
+    // TODO: support batch size > 1
+    if (shapeA[0] != 1)
+        return nullptr;
+    if (op->getPh() == 0 && op->getSh() == 1 && shapeW[2] == 1 &&
+        shapeW[3] == 1) {
+        auto g = make_ref<GraphObj>(runtime);
+        auto A =
+            g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
+                                 vector{shapeA[1], shapeA[0] * shapeA[2] *
+                                                       shapeA[3]}) // [C, N*H*W]
+                ->getOutput();
+        auto B = g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(1)), nullptr,
+                                      vector{shapeW[0], shapeW[1]}) // [F, C]
+                     ->getOutput();
+        auto O =
+            g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, N*H*W]
+        g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
+                                        op->getOutput()->getDims());
+        return g;
+    }
+    return nullptr;
+}

 // Graph NMutator::transformConv1xk(Operator op) {
 //     auto convOp = dynamic_cast<ConvOp *>(op);
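The new rule rewrites a pad-0, stride-1, 1x1 convolution with batch size 1 as reshape -> matmul -> reshape. A quick NumPy check of the underlying algebra (not part of this commit; shapes are arbitrary illustrative values):

```python
import numpy as np

# Guard conditions of transformConv1x1: N == 1, ph == 0, sh == 1, R == S == 1.
N, C, H, W, F = 1, 16, 14, 14, 32
A = np.random.rand(N, C, H, W).astype(np.float32)  # input  [N, C, H, W]
K = np.random.rand(F, C, 1, 1).astype(np.float32)  # weight [F, C, 1, 1]

# Reference: the 1x1 convolution computed directly.
ref = np.einsum("nchw,fcrs->nfhw", A, K)

# The rule: A -> [C, N*H*W], K -> [F, C], multiply, reshape back.
a = A.reshape(C, N * H * W)        # a plain reshape is valid only when N == 1
b = K.reshape(F, C)
out = (b @ a).reshape(N, F, H, W)  # [F, N*H*W] -> [N, F, H, W]

assert np.allclose(ref, out, atol=1e-5)
```

For N > 1 the [N, C, H, W] -> [C, N*H*W] step would need a transpose rather than a reshape, which is why the rule bails out on shapeA[0] != 1 and why the old commented-out version noted that a transpose is required for BS > 1.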
@@ -122,6 +122,6 @@ if __name__ == "__main__":
         [3, 2, 2, 5, 8, 8, 6, 90])
     # g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)

-    save_onnx(g, f"optimized_{name}.onnx")
+    save_onnx(g, f"opt_{name}.onnx")
     verify_graphs(runtime, original_g, g)
     run_and_evaluate(runtime, g)