forked from jiuyuan/InfiniTensor

Add: rule of Any+Relu->Any

commit c6e7748786 (parent d0ae48d21d)
@@ -48,6 +48,14 @@ class GraphObj : public Object {
         return opClone;
     }

+    Operator cloneOpAndCreateOutputs(Operator op, TensorVec inputs) {
+        auto shapes = *op->inferShape(inputs);
+        vector<Tensor> outputs;
+        for (auto shape : shapes)
+            outputs.emplace_back(addTensor(shape));
+        return cloneOperator(op, inputs, outputs);
+    }
+
     const TensorVec &getTensors() const { return tensors; }
     const OpVec &getOperators() const { return ops; }
     OpVec getComputeOps() const;
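Note: the new `cloneOpAndCreateOutputs` infers output shapes from the given
inputs and allocates fresh tensors via `addTensor`, so callers no longer have
to precompute output dims. A minimal usage sketch (the `srcGraph`, `op`, and
`g` setup is assumed, not part of this diff):

    // Clone an operator into graph `g`, letting the graph create the
    // outputs from op->inferShape({input}) instead of passing them in.
    Tensor input = g->cloneTensor(srcGraph->getInputs()[0]);
    Operator cloned = g->cloneOpAndCreateOutputs(op, {input});
    Tensor out = cloned->getOutput(); // shape inferred, not hand-computed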
@@ -22,6 +22,7 @@ class AnyObj : public OperatorObj {
     int numOutputs() const override { return outputs.size(); }

     const string getKernelName() const;
+    void setAttr(int i, int v) { attr[i] = v; }
     vector<int> getOpAttrVector() const override;
     vector<int> getWorkloadVector() const override;
 };
@@ -35,6 +35,8 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
 bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }

 bool OperatorObj::isMemBoundOp() const {
+    if (type == OpType::Any)
+        return true; // TODO: check operator attributes
     return type == OpType::MemBound || type == OpType::Reshape ||
            type == OpType::Activation || type == OpType::Transpose ||
            type == OpType::Relu || type == OpType::Tanh ||
@@ -928,19 +928,28 @@ Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
     IT_ASSERT(inputGraph->getInputs().size() == 1);
     IT_ASSERT(inputGraph->getOutputs().size() == 1);
     IT_ASSERT(ops.size() > 0,
-              "TODO: If there is no op left, how to return an empty graph?");
+              "TODO: If there is no op left, how to return an empty graph? " +
+                  inputGraph->toString());
     auto input = g->cloneTensor(inputGraph->getInputs()[0]);
+    auto graphOutput = g->cloneTensor(inputGraph->getOutputs()[0]);
     for (size_t i = 0; i < ops.size(); ++i) {
-        auto output = (i + 1 == ops.size())
-                          ? inputGraph->getOutputs()[0]
-                          : g->addTensor(ops[i]->getOutput()->getDims());
-        input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
+        if (i + 1 == ops.size() &&
+            ops[i]->getOutput()->getDims() == graphOutput->getDims()) {
+            input =
+                g->cloneOperator(ops[i], {input}, {graphOutput})->getOutput();
+        } else { // If it is not the last op or the output shape mismatches
+            input = g->cloneOpAndCreateOutputs(ops[i], {input})->getOutput();
+        }
     }
+    // Add a reshape to match the original graph if necessary
+    if (g->getOutputs()[0]->getDims() != graphOutput->getDims())
+        g->addOpWithOutputs<ReshapeObj>(input, graphOutput);
     return g;
 }

 Graph NMutator::eliminateVertically(const Graph &inputGraph) {
     auto ops = inputGraph->getOperators();
     bool funcHasOptmization = false;

     IT_ASSERT(!ops.empty());
     for (auto &op : ops) {
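With this change, constructGraphByOperatorChain writes into the original
output tensor only when the last op's shape already matches; otherwise it
lets cloneOpAndCreateOutputs allocate the outputs and patches the shape at
the end. A rough illustration with hypothetical shapes:

    // Chain ends in {8, 8} but the original graph's output is {64}:
    //   ... -> op_n -> t{8,8} -> ReshapeObj -> graphOutput{64}
    // The appended reshape keeps the rebuilt graph's output shape
    // identical to the input graph's.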
@@ -983,17 +992,19 @@ Graph NMutator::eliminateVertically(const Graph &inputGraph) {
         return tuple{isComputation, isElementwise, lastRowSwapable};
     };

-    // Reorder operators: move computation operators to the tail
+    // Reorder operators: move computation operators to the head
     for (int i = ops.size() - 2; i >= 0; --i) {
         for (int j = i; j < int(ops.size()) - 1; ++j) {
-            bool swapable = false;
             const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
                                                OpType::Tanh};
+            bool swapable = false;
             auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
             auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
-            if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
+            // check swapable conditions:
+            // (!aIsC && bIsC): ordering of computation and non-computation
+            // (aEw && bEw): elementwise
+            // (aLRS && bLRS): last dim fixed
+            if ((!aIsC && bIsC) && ((aEw && bEw) || (aLRS && bLRS)))
                 swapable = true;
             if (swapable) {
                 swap(ops[j], ops[j + 1]);
             }
         }
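The corrected predicate swaps ops[j] and ops[j + 1] when the former is a
non-computation op and the latter is a computation op, provided both are
elementwise or both leave the last dimension fixed; repeated passes bubble
computation ops toward the head. A self-contained sketch of the predicate
with hypothetical classifications (not taken from classifyOperator):

    #include <cstdio>
    int main() {
        // a: non-computation unary op; b: computation op; both keep the
        // last dimension fixed, so the (aLRS && bLRS) clause fires.
        bool aIsC = false, aEw = true, aLRS = true;
        bool bIsC = true, bEw = false, bLRS = true;
        bool swapable = (!aIsC && bIsC) && ((aEw && bEw) || (aLRS && bLRS));
        std::printf("swapable=%d\n", swapable); // prints swapable=1
        return 0;
    }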
@@ -1001,8 +1012,9 @@ Graph NMutator::eliminateVertically(const Graph &inputGraph) {

     Graph g = constructGraphByOperatorChain(ops, inputGraph);
     // Eliminate operators
-    bool haveElimination;
+    bool haveElimination = false;
     do {
         funcHasOptmization = funcHasOptmization || haveElimination;
         haveElimination = false;
         ops = g->getOperators();
         vector<Operator> newOps;
@@ -1024,6 +1036,24 @@ Graph NMutator::eliminateVertically(const Graph &inputGraph) {
                 continue;
             }
         }
+        // Operator-level fusion
+        // Any+Relu -> Any(activation=1)
+        if (i + 1 < int(ops.size())) {
+            const string name = "reduceConvRxSToNCHW";
+            if (auto op = as<AnyObj>(ops[i]);
+                op && op->getKernelName() == name) {
+                if (auto op2 = as<ReluObj>(ops[i + 1])) {
+                    if (op->getOutput() == op2->getInputs(0)) {
+                        auto newOp = make_ref<AnyObj>(*op);
+                        newOp->setAttr(0, 1); // Set activation
+                        newOps.push_back(newOp);
+                        ++i;
+                        haveElimination = true;
+                        continue;
+                    }
+                }
+            }
+        }

         // Eliminate reciprocal operators
         if (i + 1 == (int)ops.size() ||
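This is the rule named in the commit message: an Any op running the
"reduceConvRxSToNCHW" kernel, followed immediately by a Relu that consumes
its output, collapses into a single Any with attr slot 0 set to 1 (the
activation switch, per the `// Set activation` comment above). A sketch of
the rewrite, assuming attr[0] encodes the fused activation:

    // Before:  t0 = Any<reduceConvRxSToNCHW>(x);  t1 = Relu(t0);
    // After:   t1 = Any<reduceConvRxSToNCHW, attr[0]=1>(x);
    // The `++i` skips the Relu so it is never copied into newOps.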
@@ -6,10 +6,12 @@
 #include "cuda/cuda_runtime.h"
 #include "nnet/nmutator.h"
 #include "nnet/test.h"
+#include "operators/any.h"
 #include "operators/conv.h"
 #include "operators/reshape.h"
 #include "operators/softmax.h"
 #include "operators/transpose.h"
+#include "operators/unary.h"
 #include "test.h"

 namespace infini {
@@ -551,4 +553,21 @@ TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
     EXPECT_EQ(optG->getOperators().size(), 5u);
 }

+TEST(NMutator, eliminateVertically_Reduce_Reshape_Relu) {
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    const int a = 8;
+    auto input = g->addTensor({a, a});
+    auto t0 = g->addTensor({a, a});
+    vector<int> args(15, 0);
+    const string kernelName = "reduceConvRxSToNCHW";
+    g->addOpWithOutputs<AnyObj>(vector{input}, vector{t0}, kernelName, args);
+    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a * a})->getOutput();
+    t0 = g->addOp<ReluObj>(t0, nullptr)->getOutput();
+    auto mutator = make_ref<NMutator>();
+    auto optG = mutator->eliminateVertically(g);
+    dbg(optG);
+    EXPECT_EQ(optG->getOperators().size(), 2u);
+}
+
 } // namespace infini
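The test builds the chain Any("reduceConvRxSToNCHW") -> Reshape -> Relu and
expects two operators in the optimized graph: presumably the Any with its
activation attribute set, plus the Reshape to {a * a}. The standalone Relu
should disappear once eliminateVertically reorders the chain and applies the
Any+Relu fusion above.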
|