diff --git a/include/core/mutator.h b/include/core/mutator.h
index 0e446c14..ff871b4b 100644
--- a/include/core/mutator.h
+++ b/include/core/mutator.h
@@ -31,6 +31,10 @@ class Mutator {
         IT_TODO_HALT();
     }
     virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }
+
+    /// @brief Eliminate transpose and reshape operators.
+    /// @return The graph after elimination; `nullptr` if elimination fails.
+    virtual Graph eliminateVertically(const Graph &in_graph) { IT_TODO_HALT(); }
 };
 
 } // namespace infini
diff --git a/include/nnet/nmutator.h b/include/nnet/nmutator.h
index 85c9f0e4..5457034c 100644
--- a/include/nnet/nmutator.h
+++ b/include/nnet/nmutator.h
@@ -26,6 +26,7 @@ class NMutator : public Mutator {
 
    vector<Graph> run(const Graph &in_graph) override;
    Graph fuseVertically(const Graph &in_graph) override;
+    Graph eliminateVertically(const Graph &in_graph) override;
    bool isMultiBranchMergable(const Graph &in_graph) override;
 
    void setToNaiveMembound();
@@ -68,6 +69,14 @@ class NMutator : public Mutator {
 
    Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
                               Tensor output = nullptr);
+
+    /// @brief Construct a new graph from a chain of operators, using the
+    /// output of each operator as the input of the next. The input and
+    /// output tensors of inputGraph are reused as the input and output of
+    /// the newly constructed graph.
+    /// @param ops The operator chain. Operators may carry wrong input/output
+    /// shapes for their new positions.
+    /// @return The constructed graph.
+    Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
 };
 
 } // namespace infini
diff --git a/include/operators/transpose.h b/include/operators/transpose.h
index 61dc8e5a..3b771699 100644
--- a/include/operators/transpose.h
+++ b/include/operators/transpose.h
@@ -3,6 +3,8 @@
 
 namespace infini {
 class TransposeObj : public OperatorObj {
+    vector<int> transposePermute;
+
   public:
    TransposeObj(GraphObj *graph, Tensor input, Tensor output,
                 vector<int> permute);
@@ -15,7 +17,6 @@ class TransposeObj : public OperatorObj {
    std::vector<int> getPermute() const { return transposePermute; }
 
   private:
-    vector<int> transposePermute;
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
 };
diff --git a/src/core/operator.cc b/src/core/operator.cc
index b5968855..47bef3df 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -34,9 +34,10 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
 bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
 
 bool OperatorObj::isMemBoundOp() const {
-    return type == OpType::MemBound || type == OpType::Activation ||
-           type == OpType::Transpose || type == OpType::Relu ||
-           type == OpType::Tanh;
+    return type == OpType::MemBound || type == OpType::Reshape ||
+           type == OpType::Activation || type == OpType::Transpose ||
+           type == OpType::Relu || type == OpType::Tanh ||
+           type == OpType::Softmax;
 }
 
 void OperatorObj::removePredecessors(const Operator &op) {
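The nmutator.cc changes below add two pieces: constructGraphByOperatorChain, which rebuilds a single-input, single-output operator chain into a fresh graph while reusing the endpoint tensors of inputGraph, and eliminateVertically itself. A minimal sketch of the intended call pattern for the helper (illustrative only: it is a private member of NMutator, and the reordering shown here is hypothetical):

    // Inside an NMutator member function; `inputGraph` has exactly one
    // input and one output, as the helper's IT_ASSERTs require.
    vector<Operator> ops = inputGraph->getOperators();
    std::swap(ops[0], ops[1]); // some reorder decided by the caller
    Graph g = constructGraphByOperatorChain(ops, inputGraph);
    // g shares its endpoint tensors with inputGraph (same fuids), so the
    // caller can substitute g for inputGraph directly.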
diff --git a/src/nnet/nmutator.cc b/src/nnet/nmutator.cc
index 036e2992..df60107e 100644
--- a/src/nnet/nmutator.cc
+++ b/src/nnet/nmutator.cc
@@ -12,6 +12,7 @@
 #include "operators/matmul.h"
 #include "operators/membound.h"
 #include "operators/reshape.h"
+#include "operators/softmax.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"
 
@@ -684,6 +685,142 @@ Graph NMutator::transformConv1x1(Operator _op) {
 //     return graph;
 // }
 
+Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
+                                              Graph inputGraph) {
+    // Construct new graph
+    auto g = make_ref<GraphObj>(runtime);
+    IT_ASSERT(inputGraph->getInputs().size() == 1);
+    IT_ASSERT(inputGraph->getOutputs().size() == 1);
+    IT_ASSERT(ops.size() > 0,
+              "TODO: If there is no op left, how to return an empty graph?");
+    auto input = g->cloneTensor(inputGraph->getInputs()[0]);
+    for (size_t i = 0; i < ops.size(); ++i) {
+        // Reuse the graph output for the last operator; otherwise allocate a
+        // fresh intermediate tensor.
+        auto output = (i + 1 == ops.size())
+                          ? inputGraph->getOutputs()[0]
+                          : g->addTensor(ops[i]->getOutput()->getDims());
+        dbg(input->getDims(), output->getDims());
+        input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
+    }
+    return g;
+}
+
+Graph NMutator::eliminateVertically(const Graph &inputGraph) {
+    auto ops = inputGraph->getOperators();
+    IT_ASSERT(!ops.empty());
+    for (auto &op : ops) {
+        IT_ASSERT(op->isMemBoundOp());
+        IT_ASSERT_TODO(op->getInputs().size() == 1);
+        IT_ASSERT(op->getOutputs().size() == 1);
+    }
+    if (ops.size() == 1) {
+        return make_ref<GraphObj>(runtime, ops);
+    }
+
+    // Classify operators by three attributes.
+    // isComputation: performs computation
+    // isElementwise: performs elementwise computation
+    // lastRowSwapable: performs last-dimension-wise computation, which
+    // includes elementwise computation as a special case.
+    auto classifyOperator = [](Operator op) {
+        auto type = op->getOpType();
+        bool isComputation =
+            type != OpType::Reshape && type != OpType::Transpose;
+        bool isElementwise =
+            !isComputation || (type == OpType::Relu || type == OpType::Tanh);
+        bool lastRowSwapable = false;
+        if (isComputation)
+            lastRowSwapable = isElementwise || // Softmax along the last dim
+                              (type == OpType::Softmax &&
+                               as<SoftmaxObj>(op)->getAxis() ==
+                                   int(op->getOutput()->getDims().size()) - 1);
+        else {
+            if (auto t = as<TransposeObj>(op)) {
+                // The last dim remains unchanged
+                lastRowSwapable =
+                    (t->getPermute().back() == int(t->getPermute().size()) - 1);
+            } else if (auto t = as<ReshapeObj>(op)) {
+                // The last dim remains unchanged
+                lastRowSwapable = (t->getInputs(0)->getDims().back() ==
+                                   t->getOutput()->getDims().back());
+            }
+        }
+        return tuple{isComputation, isElementwise, lastRowSwapable};
+    };
+
+    // Reorder operators: move computation operators to the tail
+    for (int i = ops.size() - 2; i >= 0; --i) {
+        for (int j = i; j < int(ops.size()) - 1; ++j) {
+            bool swapable = false;
+            const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
+                                               OpType::Tanh};
+            auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
+            auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
+            if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
+                swapable = true;
+            if (swapable) {
+                swap(ops[j], ops[j + 1]);
+            }
+        }
+    }
+
+    Graph g = constructGraphByOperatorChain(ops, inputGraph);
+
+    // Eliminate operators
+    bool haveElimination;
+    do {
+        haveElimination = false;
+        ops = g->getOperators();
+        vector<Operator> newOps;
+        for (int i = 0; i < int(ops.size()); ++i) {
+            // Eliminate identity operators
+            if (auto op = as<TransposeObj>(ops[i])) {
+                auto perm = op->getPermute();
+                int j = 0;
+                for (j = 0; j < int(perm.size()); ++j)
+                    if (j != perm[j])
+                        break;
+                if (j == int(perm.size())) { // Identity permutation
+                    haveElimination = true;
+                    continue;
+                }
+            } else if (auto op = as<ReshapeObj>(ops[i])) {
+                if (op->getShape() == op->getInputs(0)->getDims()) {
+                    haveElimination = true;
+                    continue;
+                }
+            }
+
+            // Eliminate pairs of mutually inverse operators
+            if (i + 1 == (int)ops.size() ||
+                (ops[i]->getOpType() != ops[i + 1]->getOpType())) {
+                newOps.push_back(ops[i]);
+                continue;
+            }
+            if (ops[i]->getOpType() == OpType::Reshape) {
+                // Merge two adjacent reshapes into one
+                newOps.push_back(make_ref<ReshapeObj>(
+                    nullptr, ops[i]->getInputs(0), ops[i + 1]->getOutput()));
+                ++i;
+                haveElimination = true;
+            } else if (ops[i]->getOpType() == OpType::Transpose) {
+                // Compose two adjacent transposes into one
+                auto permuteA = as<TransposeObj>(ops[i])->getPermute();
+                auto permuteB = as<TransposeObj>(ops[i + 1])->getPermute();
+                vector<int> permute;
+                for (auto p : permuteB)
+                    permute.push_back(permuteA[p]);
+                newOps.push_back(
+                    make_ref<TransposeObj>(nullptr, ops[i]->getInputs(0),
+                                           ops[i + 1]->getOutput(), permute));
+                ++i;
+                haveElimination = true;
+            } else {
+                newOps.push_back(ops[i]);
+            }
+        }
+        g = constructGraphByOperatorChain(newOps, inputGraph);
+    } while (haveElimination);
+    return g;
+}
+
 Graph NMutator::fuseVertically(const Graph &inputGraph) {
    Graph optGraph = make_ref<GraphObj>(runtime);
 
@@ -780,7 +917,7 @@ Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
            shapeNew.emplace_back(shapeOrignial[i]);
    auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
    auto A2 =
-        g->addOp<TransposeObj>(A1, nullptr, vector<int>{0, 1, 3, 2})->getOutput();
+        g->addOp<TransposeObj>(A1, nullptr, vector<int>{0, 2, 1, 3})->getOutput();
    Tensor A3;
    if (output)
        A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)
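On the transpose-merging branch above: a transpose with permutation P maps output dimension i to input dimension P[i], so applying permutation A and then B collapses into a single transpose with C[i] = A[B[i]], which is exactly what the `permute.push_back(permuteA[p])` loop computes. A standalone sketch of the rule (the helper name is mine, not part of the patch):

    #include <cassert>
    #include <vector>

    // Compose two transpose permutations: applying A then B equals a
    // single transpose with C[i] = A[B[i]].
    std::vector<int> composePermute(const std::vector<int> &A,
                                    const std::vector<int> &B) {
        std::vector<int> C;
        for (int p : B)
            C.push_back(A[p]);
        return C;
    }

    int main() {
        // The pair from eliminateVertically_RTSTR_3d below composes to the
        // identity, so both transposes are eliminated.
        assert((composePermute({1, 2, 0, 3}, {2, 0, 1, 3}) ==
                std::vector<int>{0, 1, 2, 3}));
        return 0;
    }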
diff --git a/src/operators/reshape.cc b/src/operators/reshape.cc
index 8576809d..2be9e944 100644
--- a/src/operators/reshape.cc
+++ b/src/operators/reshape.cc
@@ -20,9 +20,9 @@ optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
 std::string ReshapeObj::toString() const {
    std::ostringstream os;
    os << "Reshape[" << getGuid() << "]";
-    os << "(";
+    os << "(input dims=";
    os << vecToString(inputs[0]->getDims()) << ",";
-    os << "dims=" << vecToString(dims) << ",";
+    os << "output dims=" << vecToString(dims) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index 02d35b62..da525210 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -27,7 +27,8 @@ std::string TransposeObj::toString() const {
    os << "(";
    os << vecToString(inputs[0]->getDims()) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
-    os << "output=" << outputs[0]->getGuid() << ")";
+    os << "output=" << outputs[0]->getGuid() << ",";
+    os << "perm=" << vecToString(transposePermute) << ")";
    return os.str();
 }
diff --git a/test/nnet/test_mutator.cc b/test/nnet/test_mutator.cc
index 9219c526..cfdbf90e 100644
--- a/test/nnet/test_mutator.cc
+++ b/test/nnet/test_mutator.cc
@@ -7,6 +7,9 @@
 #include "nnet/nmutator.h"
 #include "nnet/test.h"
 #include "operators/conv.h"
+#include "operators/reshape.h"
+#include "operators/softmax.h"
+#include "operators/transpose.h"
 #include "test.h"
 
 namespace infini {
@@ -479,4 +482,73 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
 //     bestGraph->print();
 //     EXPECT_TRUE(graph->verification(bestGraph.get(), true));
 // }
+
+TEST(NMutator, eliminateVertically_RTSTR) {
+    Runtime runtime = make_ref<NativeCpuRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    auto input = t0;
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    ASSERT_EQ(optG->getOperators().size(), 1u);
+    auto op = optG->getOperators()[0];
+    EXPECT_EQ(op->getOpType(), OpType::Softmax);
+    EXPECT_EQ(op->getInputs(0)->getFuid(), input->getFuid());
+    EXPECT_EQ(op->getOutput()->getFuid(), t0->getFuid());
+}
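A plausible walk-through of this test: the chain starts as Reshape → Transpose → Softmax → Transpose → Reshape. Softmax over the last dimension is lastRowSwapable, and so are the following Transpose (its permutation fixes the last dimension) and Reshape (the last extent d is unchanged), so the reorder pass bubbles the Softmax to the tail: Reshape → Transpose → Transpose → Reshape → Softmax. The two adjacent {0, 2, 1, 3} transposes then compose to the identity, the two reshapes merge into an identity reshape, and successive iterations of the do-while loop drop both identities, leaving only the Softmax that the assertions check for.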
+
+TEST(NMutator, eliminateVertically_RTST) {
+    Runtime runtime = make_ref<NativeCpuRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    ASSERT_EQ(optG->getOperators().size(), 2u);
+}
+
+TEST(NMutator, eliminateVertically_RTSTR_3d) {
+    Runtime runtime = make_ref<NativeCpuRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 1u);
+}
+
+TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
+    Runtime runtime = make_ref<NativeCpuRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 2)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 5u);
+}
+
 } // namespace infini
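In this last test, by contrast, the Softmax runs over axis 2 of a 4-d intermediate tensor, so getAxis() == 2 while dims.size() - 1 == 3, and classifyOperator reports lastRowSwapable = false for it. The swap condition never fires, the Softmax keeps the two transposes apart, nothing composes or cancels, and all five operators are expected to survive.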