forked from jiuyuan/InfiniTensor
Add: Reshape/Transpose elimination
parent acc64fd32c
commit 7277356744
@@ -31,6 +31,10 @@ class Mutator {
         IT_TODO_HALT();
     }
     virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }
+
+    /// @brief Eliminate transpose and reshape operators.
+    /// @return The graph after elimination. `nullptr` if elimination fails.
+    virtual Graph eliminateVertically(const Graph &in_graph) { IT_TODO_HALT(); }
 };

 } // namespace infini
@@ -26,6 +26,7 @@ class NMutator : public Mutator {

     vector<Graph> run(const Graph &in_graph) override;
     Graph fuseVertically(const Graph &in_graph) override;
+    Graph eliminateVertically(const Graph &in_graph) override;
     bool isMultiBranchMergable(const Graph &in_graph) override;

     void setToNaiveMembound();
@@ -68,6 +69,14 @@ class NMutator : public Mutator {

     Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
                                Tensor output = nullptr);
+
+    /// @brief Construct a new graph from a chain of operators, using the
+    /// output of each operator as the input of the next. The input and
+    /// output tensors of inputGraph become the input and output of the
+    /// newly constructed graph.
+    /// @param ops The operator chain. It can have wrong input/output shapes.
+    /// @return The newly constructed graph.
+    Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
 };

 } // namespace infini
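constructGraphByOperatorChain is the rebuild primitive used by eliminateVertically below: it re-materializes a single-input, single-output operator chain as a fresh graph. A minimal usage sketch (hypothetical driver code, not part of this commit; it assumes a chain-shaped graph g and an NMutator instance already exist):

    // Reorder a chain, then rebuild it; intermediate tensors are
    // re-allocated, while g's own input/output tensors are reused.
    vector<Operator> ops = g->getOperators();
    std::swap(ops[1], ops[2]); // e.g. move a computation operator later
    Graph rebuilt = mutator->constructGraphByOperatorChain(ops, g);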
@@ -3,6 +3,8 @@

 namespace infini {
 class TransposeObj : public OperatorObj {
+    vector<int> transposePermute;
+
   public:
     TransposeObj(GraphObj *graph, Tensor input, Tensor output,
                  vector<int> permute);
@@ -15,7 +17,6 @@ class TransposeObj : public OperatorObj {
     std::vector<int> getPermute() const { return transposePermute; }

   private:
-    vector<int> transposePermute;
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
 };
@@ -34,9 +34,10 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
 bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }

 bool OperatorObj::isMemBoundOp() const {
-    return type == OpType::MemBound || type == OpType::Activation ||
-           type == OpType::Transpose || type == OpType::Relu ||
-           type == OpType::Tanh;
+    return type == OpType::MemBound || type == OpType::Reshape ||
+           type == OpType::Activation || type == OpType::Transpose ||
+           type == OpType::Relu || type == OpType::Tanh ||
+           type == OpType::Softmax;
 }

 void OperatorObj::removePredecessors(const Operator &op) {
@@ -12,6 +12,7 @@
 #include "operators/matmul.h"
 #include "operators/membound.h"
 #include "operators/reshape.h"
+#include "operators/softmax.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"

@@ -684,6 +685,142 @@ Graph NMutator::transformConv1x1(Operator _op) {
 //     return graph;
 // }
+
+Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
+                                              Graph inputGraph) {
+    // Construct new graph
+    auto g = make_ref<GraphObj>(runtime);
+    IT_ASSERT(inputGraph->getInputs().size() == 1);
+    IT_ASSERT(inputGraph->getOutputs().size() == 1);
+    IT_ASSERT(ops.size() > 0,
+              "TODO: If there is no op left, how to return an empty graph?");
+    auto input = g->cloneTensor(inputGraph->getInputs()[0]);
+    for (size_t i = 0; i < ops.size(); ++i) {
+        auto output = (i + 1 == ops.size())
+                          ? inputGraph->getOutputs()[0]
+                          : g->addTensor(ops[i]->getOutput()->getDims());
+        dbg(input->getDims(), output->getDims());
+        input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
+    }
+    return g;
+}
+
+Graph NMutator::eliminateVertically(const Graph &inputGraph) {
+    auto ops = inputGraph->getOperators();
+
+    IT_ASSERT(!ops.empty());
+    for (auto &op : ops) {
+        IT_ASSERT(op->isMemBoundOp());
+        IT_ASSERT_TODO(op->getInputs().size() == 1);
+        IT_ASSERT(op->getOutputs().size() == 1);
+    }
+    if (ops.size() == 1) {
+        return make_ref<GraphObj>(runtime, ops);
+    }
+
+    // Set attributes for operators:
+    // isComputation: does real computation (not pure data movement)
+    // isElementwise: does elementwise computation
+    // lastRowSwapable: operates only along the last dimension, which
+    //     includes elementwise as a special case.
+    auto classifyOperator = [](Operator op) {
+        auto type = op->getOpType();
+        bool isComputation =
+            type != OpType::Reshape && type != OpType::Transpose;
+        bool isElementwise =
+            !isComputation || (type == OpType::Relu || type == OpType::Tanh);
+        bool lastRowSwapable = false;
+        if (isComputation)
+            lastRowSwapable = isElementwise || // Softmax along the last dim
+                              (type == OpType::Softmax &&
+                               as<SoftmaxObj>(op)->getAxis() ==
+                                   int(op->getOutput()->getDims().size()) - 1);
+        else {
+            if (auto t = as<TransposeObj>(op)) {
+                // Last dim remains unchanged
+                lastRowSwapable =
+                    (t->getPermute().back() == int(t->getPermute().size()) - 1);
+            } else if (auto t = as<ReshapeObj>(op)) {
+                // Last dim remains unchanged
+                lastRowSwapable = (t->getInputs(0)->getDims().back() ==
+                                   t->getOutput()->getDims().back());
+            }
+        }
+        return tuple{isComputation, isElementwise, lastRowSwapable};
+    };
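The classification drives the swap test below: lastRowSwapable marks operators whose effect is confined to the last dimension. A standalone sketch of the transpose check (plain C++, independent of the framework):

    #include <vector>
    // A transpose leaves rows intact iff it maps the last axis to itself.
    bool lastDimFixed(const std::vector<int> &perm) {
        return perm.back() == int(perm.size()) - 1;
    }
    // lastDimFixed({0, 2, 1, 3}) -> true   (the permutation used in the tests)
    // lastDimFixed({0, 1, 3, 2}) -> false  (swaps the last two axes)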
+
+    // Reorder operators: move computation operators to the tail
+    for (int i = ops.size() - 2; i >= 0; --i) {
+        for (int j = i; j < int(ops.size()) - 1; ++j) {
+            bool swapable = false;
+            const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
+                                               OpType::Tanh};
+            auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
+            auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
+            if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
+                swapable = true;
+            if (swapable) {
+                swap(ops[j], ops[j + 1]);
+            }
+        }
+    }
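This is a bubble-sort-style pass: a computation operator is pushed past a following reshape/transpose whenever it is elementwise, or when both touch only the last dimension. (Note that the unaryElementwise set is declared but not consulted by the swap condition.) A worked trace on the RTSTR chain exercised by the tests below:

    before:                      Reshape, Transpose, Softmax, Transpose, Reshape
    swap Softmax <-> Transpose:  Reshape, Transpose, Transpose, Softmax, Reshape
    swap Softmax <-> Reshape:    Reshape, Transpose, Transpose, Reshape, Softmax

The adjacent Transpose pair and Reshape pair can then be merged away by the elimination loop that follows.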
+
+    Graph g = constructGraphByOperatorChain(ops, inputGraph);
+    // Eliminate operators
+    bool haveElimination;
+    do {
+        haveElimination = false;
+        ops = g->getOperators();
+        vector<Operator> newOps;
+        for (int i = 0; i < int(ops.size()); ++i) {
+            // Eliminate identity operators
+            if (auto op = as<TransposeObj>(ops[i])) {
+                auto perm = op->getPermute();
+                int j = 0;
+                for (j = 0; j < int(perm.size()); ++j)
+                    if (j != perm[j])
+                        break;
+                if (j == int(perm.size())) {
+                    haveElimination = true;
+                    continue;
+                }
+            } else if (auto op = as<ReshapeObj>(ops[i])) {
+                if (op->getShape() == op->getInputs(0)->getDims()) {
+                    haveElimination = true;
+                    continue;
+                }
+            }
+
+            // Merge adjacent operators of the same type
+            if (i + 1 == (int)ops.size() ||
+                (ops[i]->getOpType() != ops[i + 1]->getOpType())) {
+                newOps.push_back(ops[i]);
+                continue;
+            }
+            if (ops[i]->getOpType() == OpType::Reshape) {
+                newOps.push_back(make_ref<ReshapeObj>(
+                    nullptr, ops[i]->getInputs(0), ops[i + 1]->getOutput()));
+                ++i;
+                haveElimination = true;
+            } else if (ops[i]->getOpType() == OpType::Transpose) {
+                auto permuteA = as<TransposeObj>(ops[i])->getPermute();
+                auto permuteB = as<TransposeObj>(ops[i + 1])->getPermute();
+                vector<int> permute;
+                for (auto p : permuteB)
+                    permute.push_back(permuteA[p]);
+                newOps.push_back(
+                    make_ref<TransposeObj>(nullptr, ops[i]->getInputs(0),
+                                           ops[i + 1]->getOutput(), permute));
+                ++i;
+                haveElimination = true;
+            } else {
+                newOps.push_back(ops[i]);
+            }
+        }
+        g = constructGraphByOperatorChain(newOps, inputGraph);
+    } while (haveElimination);
+    return g;
+}

 Graph NMutator::fuseVertically(const Graph &inputGraph) {
     Graph optGraph = make_ref<GraphObj>(runtime);
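Two adjacent transposes are merged by composing their permutations. With the convention out[i] = in[perm[i]], applying permA and then permB gives combined[i] = permA[permB[i]], which is exactly what the inner loop above builds. A self-contained sketch:

    #include <vector>
    std::vector<int> composePermute(const std::vector<int> &permA,
                                    const std::vector<int> &permB) {
        std::vector<int> combined;
        for (int p : permB)
            // out2[i] = out1[permB[i]] = in[permA[permB[i]]]
            combined.push_back(permA[p]);
        return combined;
    }
    // composePermute({0,2,1,3}, {0,2,1,3}) == {0,1,2,3}: an identity
    // transpose, which the identity check removes on the next do-while pass.

Merged reshapes need no such bookkeeping: a reshape is fully described by its input and output shapes, so a pair collapses into a single ReshapeObj from the first input to the second output.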
@@ -780,7 +917,7 @@ Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
             shapeNew.emplace_back(shapeOrignial[i]);
     auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
     auto A2 =
-        g->addOp<TransposeObj>(A1, nullptr, vector{0, 1, 3, 2})->getOutput();
+        g->addOp<TransposeObj>(A1, nullptr, vector{0, 2, 1, 3})->getOutput();
     Tensor A3;
     if (output)
         A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)
@@ -20,9 +20,9 @@ optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
 std::string ReshapeObj::toString() const {
     std::ostringstream os;
     os << "Reshape[" << getGuid() << "]";
-    os << "(";
+    os << "(input dim=";
     os << vecToString(inputs[0]->getDims()) << ",";
-    os << "dims=" << vecToString(dims) << ",";
+    os << "output dims=" << vecToString(dims) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
@@ -27,7 +27,8 @@ std::string TransposeObj::toString() const {
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
-    os << "output=" << outputs[0]->getGuid() << ")";
+    os << "output=" << outputs[0]->getGuid() << ",";
+    os << "perm=" << vecToString(transposePermute) << ")";
     return os.str();
 }

@@ -7,6 +7,9 @@
 #include "nnet/nmutator.h"
 #include "nnet/test.h"
 #include "operators/conv.h"
+#include "operators/reshape.h"
+#include "operators/softmax.h"
+#include "operators/transpose.h"
 #include "test.h"

 namespace infini {
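In the test names that follow, the letters spell the operator chain: R = Reshape, T = Transpose, S = Softmax, so RTSTR builds Reshape-Transpose-Softmax-Transpose-Reshape.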
|
@ -479,4 +482,73 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
|
|||
// bestGraph->print();
|
||||
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
|
||||
// }
|
||||
|
||||
TEST(NMutator, eliminateVertically_RTSTR) {
|
||||
Runtime runtime = make_ref<CudaRuntimeObj>();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
const int a = 8, b = 4, c = 5, d = 30;
|
||||
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
|
||||
auto input = t0;
|
||||
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
|
||||
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
|
||||
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
|
||||
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
|
||||
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
|
||||
auto mutator = make_ref<NMutator>();
|
||||
auto optG = mutator->eliminateVertically(g);
|
||||
dbg(optG);
|
||||
ASSERT_EQ(optG->getOperators().size(), 1u);
|
||||
auto op = optG->getOperators()[0];
|
||||
EXPECT_EQ(op->getOpType(), OpType::Softmax);
|
||||
EXPECT_EQ(op->getInputs(0)->getFuid(), input->getFuid());
|
||||
EXPECT_EQ(op->getOutput()->getFuid(), t0->getFuid());
|
||||
}
|
||||
+
+TEST(NMutator, eliminateVertically_RTST) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    ASSERT_EQ(optG->getOperators().size(), 2u);
+}
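Without the trailing Reshape, the chain reorders to Reshape, Transpose, Transpose, Softmax; the transpose pair composes to an identity and is removed, leaving the Reshape and the Softmax, hence two operators.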
+
+TEST(NMutator, eliminateVertically_RTSTR_3d) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 1u);
+}
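Here the two permutations are mutual inverses rather than self-inverse: composing {1, 2, 0, 3} with {2, 0, 1, 3} as combined[i] = permA[permB[i]] again yields the identity {0, 1, 2, 3}, so only the Softmax survives.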
+
+TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 2)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 5u);
+}
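With softmax applied along axis 2 instead of the last axis, the operator is not last-row-swapable, so no reordering takes place; the chain has no adjacent operators of the same type, and all five operators remain.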

 } // namespace infini
|