Merge branch 'NNET_eliminateOP' into NNET_e2e

This commit is contained in:
Liyan Zheng 2023-04-23 13:35:29 +08:00
commit 5df2524ff9
10 changed files with 250 additions and 16 deletions


@@ -30,7 +30,14 @@ class Mutator {
virtual bool isMultiBranchMergable(const Graph &in_graph) {
IT_TODO_HALT();
}
/// @brief Fuse memory-bound operators.
/// @return The fused graph; returns `nullptr` on failure.
virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }
/// @brief Eliminate transpose and reshape operators.
/// @return The graph after elimination; returns `nullptr` on failure.
virtual Graph eliminateVertically(const Graph &in_graph) { IT_TODO_HALT(); }
};
} // namespace infini
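
A minimal usage sketch of the two hooks, assuming a hypothetical `mutator` (an `NMutator`) and a memory-bound operator chain `g`; it mirrors how `SearchEngine::fuseVertically` wires them later in this commit:

// Cancel transposes/reshapes first; keep the input graph if the pass fails.
Graph best = g;
if (auto eliminated = mutator->eliminateVertically(g))
    best = eliminated;
// Then fuse the remaining memory-bound operators.
if (auto fused = mutator->fuseVertically(best))
    best = fused;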


@@ -26,6 +26,7 @@ class NMutator : public Mutator {
vector<Graph> run(const Graph &in_graph) override;
Graph fuseVertically(const Graph &in_graph) override;
Graph eliminateVertically(const Graph &in_graph) override;
bool isMultiBranchMergable(const Graph &in_graph) override;
void setToNaiveMembound();
@@ -68,6 +69,14 @@ class NMutator : public Mutator {
Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
Tensor output = nullptr);
/// @brief Construct a new graph from a chain of operators, feeding each
/// operator's output into the next operator's input. The input and output
/// tensors of inputGraph are reused as the input and output of the newly
/// constructed graph.
/// @param ops The operator chain. Its operators may have wrong input/output
/// shapes.
/// @param inputGraph The graph supplying the input and output tensors.
/// @return The newly constructed graph.
Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
};
} // namespace infini
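
A sketch of the contract, as `eliminateVertically` uses it in this commit: given a single-input, single-output chain, each cloned operator consumes the previous operator's output, and `inputGraph`'s own boundary tensors are reused (illustrative call from inside `NMutator`):

// newOps may have been reordered or merged and carry stale shapes; the
// rebuilt graph re-wires the tensors along the chain.
Graph rebuilt = constructGraphByOperatorChain(newOps, inputGraph);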


@@ -3,6 +3,8 @@
namespace infini {
class TransposeObj : public OperatorObj {
vector<int> transposePermute;
public:
TransposeObj(GraphObj *graph, Tensor input, Tensor output,
vector<int> permute);
@@ -15,7 +17,6 @@ class TransposeObj : public OperatorObj {
std::vector<int> getPermute() const { return transposePermute; }
private:
vector<int> transposePermute;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};


@@ -34,9 +34,10 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorObj::isMemBoundOp() const {
return type == OpType::MemBound || type == OpType::Activation ||
type == OpType::Transpose || type == OpType::Relu ||
type == OpType::Tanh;
return type == OpType::MemBound || type == OpType::Reshape ||
type == OpType::Activation || type == OpType::Transpose ||
type == OpType::Relu || type == OpType::Tanh ||
type == OpType::Softmax;
}
void OperatorObj::removePredecessors(const Operator &op) {


@@ -496,9 +496,15 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
}
make_ref<GraphObj>(runtimeExec, chainOps)->print();
Graph optGraph =
mutator->fuseVertically(make_ref<GraphObj>(runtimeExec, chainOps));
for (auto op : optGraph->getOperators()) {
auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
// Eliminate transpose and reshape operators
if (auto eliminatedGraph = mutator->eliminateVertically(
make_ref<GraphObj>(runtimeExec, chainOps)))
bestGraph = eliminatedGraph;
// Fuse membound operators
if (auto optGraph = mutator->fuseVertically(bestGraph))
bestGraph = optGraph;
for (auto op : bestGraph->getOperators()) {
ops.emplace_back(op);
}
}


@@ -17,9 +17,7 @@ CudaRuntimeObj::CudaRuntimeObj()
checkCublasError(cublasCreate(&cublas));
checkCudnnError(cudnnSetStream(cudnn, stream));
checkCublasError(cublasSetStream(cublas, stream));
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
workspaceSize = 2ll << 30; // 2 GB
workspace = alloc(workspaceSize);
}


@@ -12,6 +12,7 @@
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "operators/unary.h"
@@ -248,6 +249,8 @@ void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {
nnet::Expr NMutator::opToExpression(Operator opT) {
auto [expr, mapNameNToTensorT] = extractOp(opT);
IT_ASSERT(expr,
"Cannot convert " + opT->toString() + " to an NNet expression");
for (auto &[name, tensorT] : mapNameNToTensorT) {
IT_ASSERT(inputsNameNToTensorT.count(name) == 0);
inputsNameNToTensorT[name] = tensorT;
@@ -331,8 +334,6 @@ pair<nnet::Expr, NMutator::NameNToTensorT> NMutator::extractOp(Operator opT) {
// // else if (auto transposeOp = dynamic_cast<TransposeOp *>(opT)) {
// // return transposeOpToExpression(transposeOp);
// // }
IT_TODO_HALT_MSG("Cannot convert " + opT->toString() +
" to an NNet expression");
return {};
}
@@ -684,6 +685,142 @@ Graph NMutator::transformConv1x1(Operator _op) {
// return graph;
// }
Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
Graph inputGraph) {
// Construct new graph
auto g = make_ref<GraphObj>(runtime);
IT_ASSERT(inputGraph->getInputs().size() == 1);
IT_ASSERT(inputGraph->getOutputs().size() == 1);
IT_ASSERT(ops.size() > 0,
"TODO: if no operator is left, how should an empty graph be returned?");
auto input = g->cloneTensor(inputGraph->getInputs()[0]);
for (size_t i = 0; i < ops.size(); ++i) {
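// For the last operator, reuse inputGraph's output tensor so the rebuilt
// chain exposes the same output tensor (same fuid) as the original graph.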
auto output = (i + 1 == ops.size())
? inputGraph->getOutputs()[0]
: g->addTensor(ops[i]->getOutput()->getDims());
dbg(input->getDims(), output->getDims());
input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
}
return g;
}
Graph NMutator::eliminateVertically(const Graph &inputGraph) {
auto ops = inputGraph->getOperators();
IT_ASSERT(!ops.empty());
for (auto &op : ops) {
IT_ASSERT(op->isMemBoundOp());
IT_ASSERT_TODO(op->getInputs().size() == 1);
IT_ASSERT(op->getOutputs().size() == 1);
}
if (ops.size() == 1) {
return make_ref<GraphObj>(runtime, ops);
}
// Set attributes for operators.
// isComputation: performs computation
// isElementwise: performs elementwise computation
// lastRowSwapable: computes only within rows of the last dimension, which
// includes elementwise as a special case.
auto classifyOperator = [](Operator op) {
auto type = op->getOpType();
bool isComputation =
type != OpType::Reshape && type != OpType::Transpose;
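// Note: layout ops (reshape/transpose) are marked elementwise by the
// !isComputation term; among computation ops, only Relu and Tanh qualify.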
bool isElementwise =
!isComputation || (type == OpType::Relu || type == OpType::Tanh);
bool lastRowSwapable = false;
if (isComputation)
lastRowSwapable = isElementwise || // Softmax along the last dim
(type == OpType::Softmax &&
as<SoftmaxObj>(op)->getAxis() ==
int(op->getOutput()->getDims().size()) - 1);
else {
if (auto t = as<TransposeObj>(op)) {
// Last dim remains unchanged
lastRowSwapable =
(t->getPermute().back() == int(t->getPermute().size()) - 1);
} else if (auto t = as<ReshapeObj>(op)) {
// Last dim remains unchanged
lastRowSwapable = (t->getInputs(0)->getDims().back() ==
t->getOutput()->getDims().back());
}
}
return tuple{isComputation, isElementwise, lastRowSwapable};
};
// Reorder operators: move computation operators to the tail
for (int i = ops.size() - 2; i >= 0; --i) {
for (int j = i; j < int(ops.size()) - 1; ++j) {
bool swapable = false;
const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
OpType::Tanh};
auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
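// ops[j] computes while ops[j + 1] only changes layout; they may swap
// when ops[j] is elementwise, or when both are last-row-swapable (each
// works row-by-row along the last dimension).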
if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
swapable = true;
if (swapable) {
swap(ops[j], ops[j + 1]);
}
}
}
Graph g = constructGraphByOperatorChain(ops, inputGraph);
// Eliminate operators
bool haveElimination;
do {
haveElimination = false;
ops = g->getOperators();
vector<Operator> newOps;
for (int i = 0; i < int(ops.size()); ++i) {
// Eliminate identity operators
if (auto op = as<TransposeObj>(ops[i])) {
auto perm = op->getPermute();
int j = 0;
for (j = 0; j < int(perm.size()); ++j)
if (j != perm[j])
break;
if (j == int(perm.size())) {
haveElimination = true;
continue;
}
} else if (auto op = as<ReshapeObj>(ops[i])) {
if (op->getShape() == op->getInputs(0)->getDims()) {
haveElimination = true;
continue;
}
}
// Merge adjacent operators of the same type; reciprocal pairs compose to
// identities, which the identity elimination above removes in the next pass.
if (i + 1 == (int)ops.size() ||
(ops[i]->getOpType() != ops[i + 1]->getOpType())) {
newOps.push_back(ops[i]);
continue;
}
if (ops[i]->getOpType() == OpType::Reshape) {
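// Two consecutive reshapes collapse into a single reshape from the
// first op's input directly to the second op's output.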
newOps.push_back(make_ref<ReshapeObj>(
nullptr, ops[i]->getInputs(0), ops[i + 1]->getOutput()));
++i;
haveElimination = true;
} else if (ops[i]->getOpType() == OpType::Transpose) {
auto permuteA = as<TransposeObj>(ops[i])->getPermute();
auto permuteB = as<TransposeObj>(ops[i + 1])->getPermute();
vector<int> permute;
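// Compose the permutations: applying permuteA then permuteB equals one
// transpose with permute[i] = permuteA[permuteB[i]].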
for (auto p : permuteB)
permute.push_back(permuteA[p]);
newOps.push_back(
make_ref<TransposeObj>(nullptr, ops[i]->getInputs(0),
ops[i + 1]->getOutput(), permute));
++i;
haveElimination = true;
} else {
newOps.push_back(ops[i]);
}
}
g = constructGraphByOperatorChain(newOps, inputGraph);
} while (haveElimination);
return g;
}
Graph NMutator::fuseVertically(const Graph &inputGraph) {
Graph optGraph = make_ref<GraphObj>(runtime);
@@ -700,6 +837,8 @@ Graph NMutator::fuseVertically(const Graph &inputGraph) {
std::vector<nnet::Expr> exprs;
for (const auto &op : chainOps) {
auto [expr, _] = extractOp(op);
if (!expr)
return nullptr;
exprs.emplace_back(expr);
// dbg(op, infini::as<nnet::RangeOpNode>(expr)->getFullExpression());
}
@@ -780,7 +919,7 @@ Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
shapeNew.emplace_back(shapeOrignial[i]);
auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
auto A2 =
g->addOp<TransposeObj>(A1, nullptr, vector{0, 1, 3, 2})->getOutput();
g->addOp<TransposeObj>(A1, nullptr, vector{0, 2, 1, 3})->getOutput();
Tensor A3;
if (output)
A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)

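The permutation-composition rule in eliminateVertically above can be checked in isolation. A self-contained sketch (composePermute is a hypothetical helper, not part of this commit), using the permutes from the eliminateVertically_RTSTR_3d test below:

#include <cassert>
#include <vector>
using std::vector;

// Applying permutation a, then b, equals one transpose with r[i] = a[b[i]].
vector<int> composePermute(const vector<int> &a, const vector<int> &b) {
    vector<int> r;
    for (int p : b)
        r.push_back(a[p]);
    return r;
}

int main() {
    // {1,2,0,3} followed by {2,0,1,3} composes to the identity permutation,
    // so the transpose pair cancels and only the softmax survives.
    auto r = composePermute({1, 2, 0, 3}, {2, 0, 1, 3});
    assert((r == vector<int>{0, 1, 2, 3}));
    return 0;
}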

@@ -20,9 +20,9 @@ optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
std::string ReshapeObj::toString() const {
std::ostringstream os;
os << "Reshape[" << getGuid() << "]";
os << "(";
os << "(input dim=";
os << vecToString(inputs[0]->getDims()) << ",";
os << "dims=" << vecToString(dims) << ",";
os << "output dims=" << vecToString(dims) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();


@@ -27,7 +27,8 @@ std::string TransposeObj::toString() const {
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
os << "output=" << outputs[0]->getGuid() << ",";
os << "perm=" << vecToString(transposePermute) << ")";
return os.str();
}


@@ -7,6 +7,9 @@
#include "nnet/nmutator.h"
#include "nnet/test.h"
#include "operators/conv.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "test.h"
namespace infini {
@@ -479,4 +482,73 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
// bestGraph->print();
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// }
TEST(NMutator, eliminateVertically_RTSTR) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
auto input = t0;
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
ASSERT_EQ(optG->getOperators().size(), 1u);
auto op = optG->getOperators()[0];
EXPECT_EQ(op->getOpType(), OpType::Softmax);
EXPECT_EQ(op->getInputs(0)->getFuid(), input->getFuid());
EXPECT_EQ(op->getOutput()->getFuid(), t0->getFuid());
}
TEST(NMutator, eliminateVertically_RTST) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
ASSERT_EQ(optG->getOperators().size(), 2u);
}
TEST(NMutator, eliminateVertically_RTSTR_3d) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
EXPECT_EQ(optG->getOperators().size(), 1u);
}
TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
const int a = 8, b = 4, c = 5, d = 30;
auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
t0 = g->addOp<SoftmaxObj>(t0, nullptr, 2)->getOutput();
t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
auto mutator = make_ref<NMutator>();
auto optG = mutator->eliminateVertically(g);
dbg(optG);
EXPECT_EQ(optG->getOperators().size(), 5u);
}
} // namespace infini