Add: fusion rule Any+Relu -> Any

Liyan Zheng 2023-04-28 21:46:54 +08:00
parent d0ae48d21d
commit c6e7748786
5 changed files with 72 additions and 12 deletions

View File

@@ -48,6 +48,14 @@ class GraphObj : public Object {
        return opClone;
    }
    Operator cloneOpAndCreateOutputs(Operator op, TensorVec inputs) {
        auto shapes = *op->inferShape(inputs);
        vector<Tensor> outputs;
        for (auto shape : shapes)
            outputs.emplace_back(addTensor(shape));
        return cloneOperator(op, inputs, outputs);
    }
    const TensorVec &getTensors() const { return tensors; }
    const OpVec &getOperators() const { return ops; }
    OpVec getComputeOps() const;
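The new cloneOpAndCreateOutputs helper clones an operator onto this graph with the given inputs, infers the output shapes from those inputs, and allocates fresh output tensors instead of reusing tensors from the source graph. A minimal usage sketch (not part of this commit; rebuildChain is a hypothetical helper built only on the API shown above):

    // Rebuild a single-input operator chain on graph g, letting every cloned
    // operator allocate its own outputs via shape inference rather than
    // cloning the original output tensors.
    Graph rebuildChain(Graph g, const vector<Operator> &chain, Tensor input) {
        for (const auto &op : chain)
            input = g->cloneOpAndCreateOutputs(op, {input})->getOutput();
        return g;
    }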

View File

@@ -22,6 +22,7 @@ class AnyObj : public OperatorObj {
    int numOutputs() const override { return outputs.size(); }
    const string getKernelName() const;
    void setAttr(int i, int v) { attr[i] = v; }
    vector<int> getOpAttrVector() const override;
    vector<int> getWorkloadVector() const override;
};

View File

@@ -35,6 +35,8 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }
bool OperatorObj::isMemBoundOp() const {
    if (type == OpType::Any)
        return true; // TODO: check operator attributes
    return type == OpType::MemBound || type == OpType::Reshape ||
           type == OpType::Activation || type == OpType::Transpose ||
           type == OpType::Relu || type == OpType::Tanh ||

View File

@@ -928,19 +928,28 @@ Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
    IT_ASSERT(inputGraph->getInputs().size() == 1);
    IT_ASSERT(inputGraph->getOutputs().size() == 1);
    IT_ASSERT(ops.size() > 0,
              "TODO: If there is no op left, how to return an empty graph?");
              "TODO: If there is no op left, how to return an empty graph? " +
                  inputGraph->toString());
    auto input = g->cloneTensor(inputGraph->getInputs()[0]);
    auto graphOutput = g->cloneTensor(inputGraph->getOutputs()[0]);
    for (size_t i = 0; i < ops.size(); ++i) {
        auto output = (i + 1 == ops.size())
                          ? inputGraph->getOutputs()[0]
                          : g->addTensor(ops[i]->getOutput()->getDims());
        input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
        if (i + 1 == ops.size() &&
            ops[i]->getOutput()->getDims() == graphOutput->getDims()) {
            input =
                g->cloneOperator(ops[i], {input}, {graphOutput})->getOutput();
        } else { // If it is not the last op or the output shape mismatches
            input = g->cloneOpAndCreateOutputs(ops[i], {input})->getOutput();
        }
    }
    // Add a reshape to match original graph if necessary
    if (g->getOutputs()[0]->getDims() != graphOutput->getDims())
        g->addOpWithOutputs<ReshapeObj>(input, graphOutput);
    return g;
}
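// (Illustration, not part of the commit: if the rebuilt chain ends with an
// output of shape {8, 8} while the original graph output is {64}, the dims
// differ, so the last iteration falls back to cloneOpAndCreateOutputs and
// the final Reshape {8, 8} -> {64} keeps the optimized graph
// shape-compatible with the input graph.)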
Graph NMutator::eliminateVertically(const Graph &inputGraph) {
    auto ops = inputGraph->getOperators();
    bool funcHasOptmization = false;
    IT_ASSERT(!ops.empty());
    for (auto &op : ops) {
@@ -983,17 +992,19 @@ Graph NMutator::eliminateVertically(const Graph &inputGraph) {
        return tuple{isComputation, isElementwise, lastRowSwapable};
    };
    // Reorder operators: move computation operators to the tail
    // Reorder operators: move computation operators to the head
    for (int i = ops.size() - 2; i >= 0; --i) {
        for (int j = i; j < int(ops.size()) - 1; ++j) {
            bool swapable = false;
            const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
                                               OpType::Tanh};
            bool swapable = false;
            auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
            auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
            if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
            // check swapable conditions:
            // (!aIsC && bIsC): ordering of computation and non-computation
            // (aEw && bEw): both elementwise
            // (aLRS && bLRS): last dim fixed
            if ((!aIsC && bIsC) && ((aEw && bEw) || (aLRS && bLRS)))
                swapable = true;
            if (swapable) {
                swap(ops[j], ops[j + 1]);
            }
        }
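A swap only happens when ops[j + 1] is a computation operator, ops[j] is not, and at least one of the safety conditions holds. A worked example with assumed classification flags (illustrative only, not taken from the commit):

    // Suppose classifyOperator returns
    //   ops[j]     (say, a Relu):      aIsC = false, aEw = true, aLRS = true
    //   ops[j + 1] (a computation op): bIsC = true,  bEw = true, bLRS = false
    // Then (!aIsC && bIsC) holds and (aEw && bEw) holds, so the pair is
    // swapped and the computation op bubbles one step toward the head.
    // If instead bEw and bLRS were both false, neither (aEw && bEw) nor
    // (aLRS && bLRS) would hold and the order would stay unchanged.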
@@ -1001,8 +1012,9 @@ Graph NMutator::eliminateVertically(const Graph &inputGraph) {
    Graph g = constructGraphByOperatorChain(ops, inputGraph);
    // Eliminate operators
    bool haveElimination;
    bool haveElimination = false;
    do {
        funcHasOptmization = funcHasOptmization || haveElimination;
        haveElimination = false;
        ops = g->getOperators();
        vector<Operator> newOps;
@@ -1024,6 +1036,24 @@ Graph NMutator::eliminateVertically(const Graph &inputGraph) {
                    continue;
                }
            }
            // Operator-level fusion
            // Any+Relu -> Any(activation=1)
            if (i + 1 < int(ops.size())) {
                const string name = "reduceConvRxSToNCHW";
                if (auto op = as<AnyObj>(ops[i]);
                    op && op->getKernelName() == name) {
                    if (auto op2 = as<ReluObj>(ops[i + 1])) {
                        if (op->getOutput() == op2->getInputs(0)) {
                            auto newOp = make_ref<AnyObj>(*op);
                            newOp->setAttr(0, 1); // Set activation
                            newOps.push_back(newOp);
                            ++i;
                            haveElimination = true;
                            continue;
                        }
                    }
                }
            }
            // Eliminate reciprocal operators
            if (i + 1 == (int)ops.size() ||
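In graph terms, the new rule collapses a directly connected Any("reduceConvRxSToNCHW") followed by a Relu into a single fused Any; a sketch of the intended effect (illustrative, not part of the diff):

    // before: x -> Any("reduceConvRxSToNCHW") -> Relu -> y
    // after : x -> Any("reduceConvRxSToNCHW", attr[0] = 1) -> y
    // The standalone Relu disappears; attr[0] == 1 tells the Any kernel to
    // apply the activation itself (per the setAttr(0, 1) call above).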

View File

@@ -6,10 +6,12 @@
#include "cuda/cuda_runtime.h"
#include "nnet/nmutator.h"
#include "nnet/test.h"
#include "operators/any.h"
#include "operators/conv.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
@@ -551,4 +553,21 @@ TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
    EXPECT_EQ(optG->getOperators().size(), 5u);
}
TEST(NMutator, eliminateVertically_Reduce_Reshape_Relu) {
    Runtime runtime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(runtime);
    const int a = 8;
    auto input = g->addTensor({a, a});
    auto t0 = g->addTensor({a, a});
    vector<int> args(15, 0);
    const string kernelName = "reduceConvRxSToNCHW";
    g->addOpWithOutputs<AnyObj>(vector{input}, vector{t0}, kernelName, args);
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a * a})->getOutput();
    t0 = g->addOp<ReluObj>(t0, nullptr)->getOutput();
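    // The chain is Any -> Reshape -> Relu; the new fusion rule is expected
    // to fold the Relu into the Any kernel (activation attribute), so only
    // two operators should remain after eliminateVertically.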
    auto mutator = make_ref<NMutator>();
    auto optG = mutator->eliminateVertically(g);
    dbg(optG);
    EXPECT_EQ(optG->getOperators().size(), 2u);
}
} // namespace infini