Add: Reshape/Transpose elimination

Liyan Zheng 2023-04-23 02:10:05 +08:00
parent acc64fd32c
commit 7277356744
8 changed files with 233 additions and 8 deletions

View File

@@ -31,6 +31,10 @@ class Mutator {
IT_TODO_HALT();
}
virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }
/// @brief Eliminate transpose and reshape operators.
/// @return The graph after elimination, or `nullptr` if elimination fails.
virtual Graph eliminateVertically(const Graph &in_graph) { IT_TODO_HALT(); }
};
} // namespace infini
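A minimal caller sketch for the new hook (hypothetical driver code, not part of this commit; `graph` and the `NMutator` subclass are assumed), illustrating the `nullptr`-on-failure contract documented above:

// Hypothetical usage sketch; names here are illustrative only.
Ref<Mutator> mutator = make_ref<NMutator>();
if (Graph simplified = mutator->eliminateVertically(graph))
    graph = simplified; // Transpose/Reshape pairs removed where possible
// On failure the hook returns nullptr and the original graph is kept.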

View File

@@ -26,6 +26,7 @@ class NMutator : public Mutator {
vector<Graph> run(const Graph &in_graph) override;
Graph fuseVertically(const Graph &in_graph) override;
Graph eliminateVertically(const Graph &in_graph) override;
bool isMultiBranchMergable(const Graph &in_graph) override;
void setToNaiveMembound();
@@ -68,6 +69,14 @@ class NMutator : public Mutator {
Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
Tensor output = nullptr);
/// @brief Construct a new graph from a chain of operators, using the
/// output of each operator as the input of the next. The input and output
/// tensors of inputGraph become the input and output tensors of the newly
/// constructed graph.
/// @param ops The operator chain. Its operators may have wrong
/// input/output shapes.
/// @param inputGraph The graph supplying the boundary input/output tensors.
/// @return The newly constructed graph.
Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
};
} // namespace infini

View File

@@ -3,6 +3,8 @@
namespace infini {
class TransposeObj : public OperatorObj {
+ vector<int> transposePermute;
public:
TransposeObj(GraphObj *graph, Tensor input, Tensor output,
vector<int> permute);
@@ -15,7 +17,6 @@ class TransposeObj : public OperatorObj {
std::vector<int> getPermute() const { return transposePermute; }
private:
- vector<int> transposePermute;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};

View File

@@ -34,9 +34,10 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorObj::isMemBoundOp() const {
- return type == OpType::MemBound || type == OpType::Activation ||
-        type == OpType::Transpose || type == OpType::Relu ||
-        type == OpType::Tanh;
+ return type == OpType::MemBound || type == OpType::Reshape ||
+        type == OpType::Activation || type == OpType::Transpose ||
+        type == OpType::Relu || type == OpType::Tanh ||
+        type == OpType::Softmax;
}
void OperatorObj::removePredecessors(const Operator &op) {

View File

@@ -12,6 +12,7 @@
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "operators/unary.h"
@@ -684,6 +685,142 @@ Graph NMutator::transformConv1x1(Operator _op) {
// return graph;
// }
Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
Graph inputGraph) {
// Construct new graph
auto g = make_ref<GraphObj>(runtime);
IT_ASSERT(inputGraph->getInputs().size() == 1);
IT_ASSERT(inputGraph->getOutputs().size() == 1);
IT_ASSERT(ops.size() > 0,
"TODO: If there is no op left, how to return an empty graph?");
auto input = g->cloneTensor(inputGraph->getInputs()[0]);
for (size_t i = 0; i < ops.size(); ++i) {
auto output = (i + 1 == ops.size())
? inputGraph->getOutputs()[0]
: g->addTensor(ops[i]->getOutput()->getDims());
dbg(input->getDims(), output->getDims());
input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
}
return g;
}
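// Worked sketch of the wiring above, assuming ops = {Reshape, Softmax} and
// an inputGraph with input X and output Y (illustrative names only):
//   i = 0: input = clone of X, output = a fresh tensor with Reshape's dims;
//   i = 1: input = the Reshape's output, output = Y itself, so the boundary
//          tensors of inputGraph are preserved in the new graph.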
Graph NMutator::eliminateVertically(const Graph &inputGraph) {
auto ops = inputGraph->getOperators();
IT_ASSERT(!ops.empty());
for (auto &op : ops) {
IT_ASSERT(op->isMemBoundOp());
IT_ASSERT_TODO(op->getInputs().size() == 1);
IT_ASSERT(op->getOutputs().size() == 1);
}
if (ops.size() == 1) {
return make_ref<GraphObj>(runtime, ops);
}
// Set attributes for operators.
// isComputation: performs computation
// isElementwise: performs elementwise computation
// lastRowSwapable: operates only along the last dimension, which includes
// elementwise computation as a special case.
auto classifyOperator = [](Operator op) {
auto type = op->getOpType();
bool isComputation =
type != OpType::Reshape && type != OpType::Transpose;
bool isElementwise =
!isComputation || (type == OpType::Relu || type == OpType::Tanh);
bool lastRowSwapable = false;
if (isComputation)
lastRowSwapable = isElementwise || // Softmax along the last dim
(type == OpType::Softmax &&
as<SoftmaxObj>(op)->getAxis() ==
int(op->getOutput()->getDims().size()) - 1);
else {
if (auto t = as<TransposeObj>(op)) {
// Last dim remains unchanged
lastRowSwapable =
(t->getPermute().back() == int(t->getPermute().size()) - 1);
} else if (auto t = as<ReshapeObj>(op)) {
// Last dim remains unchanged
lastRowSwapable = (t->getInputs(0)->getDims().back() ==
t->getOutput()->getDims().back());
}
}
return tuple{isComputation, isElementwise, lastRowSwapable};
};
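// Example classifications under the rules above, for the operators used in
// the tests of this commit: Relu -> (computation, elementwise, swapable);
// Softmax on the last axis -> (computation, non-elementwise, swapable);
// Transpose{0, 2, 1, 3} and a Reshape keeping the last dim ->
// (non-computation, swapable); Softmax on an inner axis -> not swapable.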
// Reorder operators: move computation operators to the tail
for (int i = ops.size() - 2; i >= 0; --i) {
for (int j = i; j < int(ops.size()) - 1; ++j) {
bool swapable = false;
const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
OpType::Tanh};
auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
swapable = true;
if (swapable) {
swap(ops[j], ops[j + 1]);
}
}
}
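// Trace on the RTSTR test chain: [Reshape, Transpose, Softmax(last axis),
// Transpose, Reshape]. Softmax is a last-row-swapable computation op, so
// the bubble passes move it past the trailing Transpose and Reshape,
// yielding [Reshape, Transpose, Transpose, Reshape, Softmax].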
Graph g = constructGraphByOperatorChain(ops, inputGraph);
// Eliminate operators
bool haveElimination;
do {
haveElimination = false;
ops = g->getOperators();
vector<Operator> newOps;
for (int i = 0; i < int(ops.size()); ++i) {
// Eliminate identity operators
if (auto op = as<TransposeObj>(ops[i])) {
auto perm = op->getPermute();
int j = 0;
for (j = 0; j < int(perm.size()); ++j)
if (j != perm[j])
break;
if (j == int(perm.size())) {
haveElimination = true;
continue;
}
} else if (auto op = as<ReshapeObj>(ops[i])) {
if (op->getShape() == op->getInputs(0)->getDims()) {
haveElimination = true;
continue;
}
}
// Merge adjacent operators of the same type. Mutually inverse pairs
// merge into identities, which are removed in the next pass.
if (i + 1 == (int)ops.size() ||
(ops[i]->getOpType() != ops[i + 1]->getOpType())) {
newOps.push_back(ops[i]);
continue;
}
if (ops[i]->getOpType() == OpType::Reshape) {
newOps.push_back(make_ref<ReshapeObj>(
nullptr, ops[i]->getInputs(0), ops[i + 1]->getOutput()));
++i;
haveElimination = true;
} else if (ops[i]->getOpType() == OpType::Transpose) {
auto permuteA = as<TransposeObj>(ops[i])->getPermute();
auto permuteB = as<TransposeObj>(ops[i + 1])->getPermute();
vector<int> permute;
for (auto p : permuteB)
permute.push_back(permuteA[p]);
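// The merged op maps output dim i to input dim permuteA[permuteB[i]].
// E.g., {0, 2, 1, 3} followed by {0, 2, 1, 3} (as in the RTSTR test)
// composes to the identity {0, 1, 2, 3}, removed in the next iteration.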
newOps.push_back(
make_ref<TransposeObj>(nullptr, ops[i]->getInputs(0),
ops[i + 1]->getOutput(), permute));
++i;
haveElimination = true;
} else {
newOps.push_back(ops[i]);
}
}
g = constructGraphByOperatorChain(newOps, inputGraph);
} while (haveElimination);
return g;
}
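// End-to-end sketch on the RTSTR chain: after reordering to [Reshape,
// Transpose, Transpose, Reshape, Softmax], the two Transposes merge into
// an identity and are dropped; the two Reshapes then become adjacent,
// merge, and are likewise dropped, leaving only the Softmax (see the
// eliminateVertically_RTSTR test below).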
Graph NMutator::fuseVertically(const Graph &inputGraph) {
Graph optGraph = make_ref<GraphObj>(runtime);
@@ -780,7 +917,7 @@ Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
shapeNew.emplace_back(shapeOrignial[i]);
auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
auto A2 =
- g->addOp<TransposeObj>(A1, nullptr, vector{0, 1, 3, 2})->getOutput();
+ g->addOp<TransposeObj>(A1, nullptr, vector{0, 2, 1, 3})->getOutput();
Tensor A3;
if (output)
A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)

View File

@@ -20,9 +20,9 @@ optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
std::string ReshapeObj::toString() const {
std::ostringstream os;
os << "Reshape[" << getGuid() << "]";
os << "(";
os << "(input dim=";
os << vecToString(inputs[0]->getDims()) << ",";
os << "dims=" << vecToString(dims) << ",";
os << "output dims=" << vecToString(dims) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();

View File

@@ -27,7 +27,8 @@ std::string TransposeObj::toString() const {
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
os << "output=" << outputs[0]->getGuid() << ",";
os << "perm=" << vecToString(transposePermute) << ")";
return os.str();
}

View File

@@ -7,6 +7,9 @@
#include "nnet/nmutator.h"
#include "nnet/test.h"
#include "operators/conv.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "test.h"
namespace infini {
@@ -479,4 +482,73 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
// bestGraph->print();
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// }
TEST(NMutator, eliminateVertically_RTSTR) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
auto input = t0;
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
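// The whole Reshape/Transpose sandwich cancels, leaving a single Softmax
// that reads the original input tensor and writes the original output.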
ASSERT_EQ(optG->getOperators().size(), 1u);
auto op = optG->getOperators()[0];
EXPECT_EQ(op->getOpType(), OpType::Softmax);
EXPECT_EQ(op->getInputs(0)->getFuid(), input->getFuid());
EXPECT_EQ(op->getOutput()->getFuid(), t0->getFuid());
}
TEST(NMutator, eliminateVertically_RTST) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
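// Without the trailing Reshape, reordering yields [Reshape, Transpose,
// Transpose, Softmax]; the two Transposes cancel, leaving Reshape + Softmax.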
ASSERT_EQ(optG->getOperators().size(), 2u);
}
TEST(NMutator, eliminateVertically_RTSTR_3d) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
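// The permutes {1, 2, 0, 3} and {2, 0, 1, 3} compose to the identity, so
// this chain also collapses to a single Softmax.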
EXPECT_EQ(optG->getOperators().size(), 1u);
}
TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 2)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
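// Softmax along axis 2 is not last-row-swapable, so no reordering or
// merging applies and all 5 operators remain.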
EXPECT_EQ(optG->getOperators().size(), 5u);
}
} // namespace infini