forked from jiuyuan/InfiniTensor

Add: Reshape/Transpose elimination

parent acc64fd32c
commit 7277356744
@@ -31,6 +31,10 @@ class Mutator {
         IT_TODO_HALT();
     }
     virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }
+
+    /// @brief Eliminate transpose and reshape.
+    /// @return The graph after elimination. `nullptr` if fails.
+    virtual Graph eliminateVertically(const Graph &in_graph) { IT_TODO_HALT(); }
 };

 } // namespace infini
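The new `eliminateVertically` hook sits next to `fuseVertically` in the `Mutator` interface. A minimal usage sketch, not part of the diff; it mirrors the tests added at the end of this commit and assumes `g` is a single chain of mem-bound operators:

    // Hedged sketch: invoke the new pass through a concrete NMutator.
    auto mutator = make_ref<NMutator>();
    Graph optG = mutator->eliminateVertically(g); // per the doc: nullptr on failure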
@@ -26,6 +26,7 @@ class NMutator : public Mutator {

     vector<Graph> run(const Graph &in_graph) override;
     Graph fuseVertically(const Graph &in_graph) override;
+    Graph eliminateVertically(const Graph &in_graph) override;
     bool isMultiBranchMergable(const Graph &in_graph) override;

     void setToNaiveMembound();
@@ -68,6 +69,14 @@ class NMutator : public Mutator {

     Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
                                Tensor output = nullptr);

+    /// @brief Construct a new graph with a chain of operators. Use the output
+    /// from the previous operator as the input of the next operator. While
+    /// constructing, the input and output tensors from inputGraph are used as
+    /// the input and output of the newly constructed graph.
+    /// @param ops The operator chain. It can have wrong input/output shapes.
+    /// @return The constructed graph.
+    Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
 };

 } // namespace infini
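`constructGraphByOperatorChain` is the workhorse that `eliminateVertically` uses twice in the implementation below: once after reordering the chain and once after each elimination round. A hedged sketch of the call pattern (names as in this diff):

    // Intermediate tensors are re-created while cloning, so the operators in
    // `ops` may carry stale input/output shapes on entry.
    vector<Operator> ops = inputGraph->getOperators();
    // ... reorder or drop operators in `ops` ...
    Graph g = constructGraphByOperatorChain(ops, inputGraph);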
@@ -3,6 +3,8 @@

 namespace infini {
 class TransposeObj : public OperatorObj {
+    vector<int> transposePermute;
+
   public:
     TransposeObj(GraphObj *graph, Tensor input, Tensor output,
                  vector<int> permute);
@@ -15,7 +17,6 @@ class TransposeObj : public OperatorObj {
     std::vector<int> getPermute() const { return transposePermute; }

   private:
-    vector<int> transposePermute;
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
 };
@@ -34,9 +34,10 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
 bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }

 bool OperatorObj::isMemBoundOp() const {
-    return type == OpType::MemBound || type == OpType::Activation ||
-           type == OpType::Transpose || type == OpType::Relu ||
-           type == OpType::Tanh;
+    return type == OpType::MemBound || type == OpType::Reshape ||
+           type == OpType::Activation || type == OpType::Transpose ||
+           type == OpType::Relu || type == OpType::Tanh ||
+           type == OpType::Softmax;
 }

 void OperatorObj::removePredecessors(const Operator &op) {
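Note the widened predicate: `Reshape`, `Transpose`, and `Softmax` now count as mem-bound. This matters because `eliminateVertically` below asserts that every operator in the input chain satisfies `isMemBoundOp()`; without this change the new pass would reject exactly the Reshape/Transpose/Softmax chains it is meant to simplify.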
@@ -12,6 +12,7 @@
 #include "operators/matmul.h"
 #include "operators/membound.h"
 #include "operators/reshape.h"
+#include "operators/softmax.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"
@@ -684,6 +685,142 @@ Graph NMutator::transformConv1x1(Operator _op) {
 //     return graph;
 // }

+Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
+                                              Graph inputGraph) {
+    // Construct new graph
+    auto g = make_ref<GraphObj>(runtime);
+    IT_ASSERT(inputGraph->getInputs().size() == 1);
+    IT_ASSERT(inputGraph->getOutputs().size() == 1);
+    IT_ASSERT(ops.size() > 0,
+              "TODO: If there is no op left, how to return an empty graph?");
+    auto input = g->cloneTensor(inputGraph->getInputs()[0]);
+    for (size_t i = 0; i < ops.size(); ++i) {
+        auto output = (i + 1 == ops.size())
+                          ? inputGraph->getOutputs()[0]
+                          : g->addTensor(ops[i]->getOutput()->getDims());
+        dbg(input->getDims(), output->getDims());
+        input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
+    }
+    return g;
+}
+
+Graph NMutator::eliminateVertically(const Graph &inputGraph) {
+    auto ops = inputGraph->getOperators();
+
+    IT_ASSERT(!ops.empty());
+    for (auto &op : ops) {
+        IT_ASSERT(op->isMemBoundOp());
+        IT_ASSERT_TODO(op->getInputs().size() == 1);
+        IT_ASSERT(op->getOutputs().size() == 1);
+    }
+    if (ops.size() == 1) {
+        return make_ref<GraphObj>(runtime, ops);
+    }
+
+    // Set attributes for operators.
+    // isComputation: performs computation
+    // isElementwise: performs elementwise computation
+    // lastRowSwapable: performs last-dimension-wise computation, which
+    //                  includes elementwise as a special case.
+    auto classifyOperator = [](Operator op) {
+        auto type = op->getOpType();
+        bool isComputation =
+            type != OpType::Reshape && type != OpType::Transpose;
+        bool isElementwise =
+            !isComputation || (type == OpType::Relu || type == OpType::Tanh);
+        bool lastRowSwapable = false;
+        if (isComputation)
+            lastRowSwapable = isElementwise || // Softmax along the last dim
+                              (type == OpType::Softmax &&
+                               as<SoftmaxObj>(op)->getAxis() ==
+                                   int(op->getOutput()->getDims().size()) - 1);
+        else {
+            if (auto t = as<TransposeObj>(op)) {
+                // Last dim remains unchanged
+                lastRowSwapable =
+                    (t->getPermute().back() == int(t->getPermute().size()) - 1);
+            } else if (auto t = as<ReshapeObj>(op)) {
+                // Last dim remains unchanged
+                lastRowSwapable = (t->getInputs(0)->getDims().back() ==
+                                   t->getOutput()->getDims().back());
+            }
+        }
+        return tuple{isComputation, isElementwise, lastRowSwapable};
+    };
+
+    // Reorder operators: move computation operators to the tail
+    for (int i = ops.size() - 2; i >= 0; --i) {
+        for (int j = i; j < int(ops.size()) - 1; ++j) {
+            bool swapable = false;
+            const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
+                                               OpType::Tanh};
+            auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
+            auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
+            if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
+                swapable = true;
+            if (swapable) {
+                swap(ops[j], ops[j + 1]);
+            }
+        }
+    }
+
+    Graph g = constructGraphByOperatorChain(ops, inputGraph);
+
+    // Eliminate operators
+    bool haveElimination;
+    do {
+        haveElimination = false;
+        ops = g->getOperators();
+        vector<Operator> newOps;
+        for (int i = 0; i < int(ops.size()); ++i) {
+            // Eliminate identity operators
+            if (auto op = as<TransposeObj>(ops[i])) {
+                auto perm = op->getPermute();
+                int j = 0;
+                for (j = 0; j < int(perm.size()); ++j)
+                    if (j != perm[j])
+                        break;
+                if (j == int(perm.size())) {
+                    haveElimination = true;
+                    continue;
+                }
+            } else if (auto op = as<ReshapeObj>(ops[i])) {
+                if (op->getShape() == op->getInputs(0)->getDims()) {
+                    haveElimination = true;
+                    continue;
+                }
+            }
+
+            // Eliminate reciprocal operators
+            if (i + 1 == (int)ops.size() ||
+                (ops[i]->getOpType() != ops[i + 1]->getOpType())) {
+                newOps.push_back(ops[i]);
+                continue;
+            }
+            if (ops[i]->getOpType() == OpType::Reshape) {
+                newOps.push_back(make_ref<ReshapeObj>(
+                    nullptr, ops[i]->getInputs(0), ops[i + 1]->getOutput()));
+                ++i;
+                haveElimination = true;
+            } else if (ops[i]->getOpType() == OpType::Transpose) {
+                auto permuteA = as<TransposeObj>(ops[i])->getPermute();
+                auto permuteB = as<TransposeObj>(ops[i + 1])->getPermute();
+                vector<int> permute;
+                for (auto p : permuteB)
+                    permute.push_back(permuteA[p]);
+                newOps.push_back(
+                    make_ref<TransposeObj>(nullptr, ops[i]->getInputs(0),
+                                           ops[i + 1]->getOutput(), permute));
+                ++i;
+                haveElimination = true;
+            } else {
+                newOps.push_back(ops[i]);
+            }
+        }
+        g = constructGraphByOperatorChain(newOps, inputGraph);
+    } while (haveElimination);
+    return g;
+}
+
 Graph NMutator::fuseVertically(const Graph &inputGraph) {
     Graph optGraph = make_ref<GraphObj>(runtime);
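Two observations make the pass easier to follow. First, the classification drives the bubble-style reordering: for the 4-D chains in the tests below, a Softmax along the last axis classifies as a computation with lastRowSwapable = true, and a Transpose with permutation {0,2,1,3} keeps the last dimension in place and is therefore also lastRowSwapable; the swap condition then lets the Softmax bubble toward the tail past that Transpose until the two reciprocal Transposes become adjacent. Second, two adjacent Transposes merge by composing their permutations. A self-contained sketch (standalone C++, not part of the diff) of that composition and of the identity check used for elimination:

    #include <cassert>
    #include <vector>
    using std::vector;

    // Applying permA then permB equals a single transpose with
    // permute[d] = permA[permB[d]], matching the loop in eliminateVertically.
    vector<int> composePermute(const vector<int> &permA,
                               const vector<int> &permB) {
        vector<int> permute;
        for (auto p : permB)
            permute.push_back(permA[p]);
        return permute;
    }

    // A transpose is droppable iff perm[d] == d for every d.
    bool isIdentityPermute(const vector<int> &perm) {
        for (int d = 0; d < int(perm.size()); ++d)
            if (perm[d] != d)
                return false;
        return true;
    }

    int main() {
        // {0,2,1,3} followed by {0,2,1,3} cancels (the RTSTR tests) ...
        assert(isIdentityPermute(composePermute({0, 2, 1, 3}, {0, 2, 1, 3})));
        // ... and so does {1,2,0,3} followed by {2,0,1,3} (the _3d test).
        assert(isIdentityPermute(composePermute({1, 2, 0, 3}, {2, 0, 1, 3})));
        return 0;
    }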
@@ -780,7 +917,7 @@ Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
         shapeNew.emplace_back(shapeOrignial[i]);
     auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
     auto A2 =
-        g->addOp<TransposeObj>(A1, nullptr, vector{0, 1, 3, 2})->getOutput();
+        g->addOp<TransposeObj>(A1, nullptr, vector{0, 2, 1, 3})->getOutput();
     Tensor A3;
     if (output)
         A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)
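This one-line fix changes which dimensions `splitTransposeMerge` swaps: on a 4-D intermediate of shape [n0, n1, n2, n3], permutation {0,1,3,2} exchanged the last two dimensions, whereas {0,2,1,3} exchanges the two middle ones, presumably the pair produced by the split, given the function's name and the surrounding Reshape calls. That is a reading of the diff, not an authoritative statement of the original bug.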
@@ -20,9 +20,9 @@ optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
 std::string ReshapeObj::toString() const {
     std::ostringstream os;
     os << "Reshape[" << getGuid() << "]";
-    os << "(";
+    os << "(input dim=";
     os << vecToString(inputs[0]->getDims()) << ",";
-    os << "dims=" << vecToString(dims) << ",";
+    os << "output dims=" << vecToString(dims) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
@@ -27,7 +27,8 @@ std::string TransposeObj::toString() const {
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
     os << "input=" << inputs[0]->getGuid() << ",";
-    os << "output=" << outputs[0]->getGuid() << ")";
+    os << "output=" << outputs[0]->getGuid() << ",";
+    os << "perm=" << vecToString(transposePermute) << ")";
     return os.str();
 }
@@ -7,6 +7,9 @@
 #include "nnet/nmutator.h"
 #include "nnet/test.h"
 #include "operators/conv.h"
+#include "operators/reshape.h"
+#include "operators/softmax.h"
+#include "operators/transpose.h"
 #include "test.h"

 namespace infini {
@@ -479,4 +482,73 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
 //     bestGraph->print();
 //     EXPECT_TRUE(graph->verification(bestGraph.get(), true));
 // }

+TEST(NMutator, eliminateVertically_RTSTR) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    auto input = t0;
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    ASSERT_EQ(optG->getOperators().size(), 1u);
+    auto op = optG->getOperators()[0];
+    EXPECT_EQ(op->getOpType(), OpType::Softmax);
+    EXPECT_EQ(op->getInputs(0)->getFuid(), input->getFuid());
+    EXPECT_EQ(op->getOutput()->getFuid(), t0->getFuid());
+}
+
+TEST(NMutator, eliminateVertically_RTST) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    ASSERT_EQ(optG->getOperators().size(), 2u);
+}
+
+TEST(NMutator, eliminateVertically_RTSTR_3d) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 1u);
+}
+
+TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8, b = 4, c = 5, d = 30;
+    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
+    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 2)->getOutput();
+    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 5u);
+}
+
 } // namespace infini
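The test names encode the operator chain: R(eshape), T(ranspose), S(oftmax). A shape trace for eliminateVertically_RTSTR: [a, b*c, d] -R-> [a, b, c, d] -T{0,2,1,3}-> [a, c, b, d] -S(axis 3)-> [a, c, b, d] -T{0,2,1,3}-> [a, b, c, d] -R-> [a, b*c, d]. Every wrapper keeps the last dimension d fixed and the softmax runs along d, so the reordering makes the reciprocal pairs adjacent, they cancel, and a single Softmax survives. In RTST the trailing Reshape is absent, so one Reshape has no partner and two operators remain; in the non_last_dim variant the softmax axis is 2, the softmax is not lastRowSwapable, no swaps happen around it, and all five operators survive.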