forked from jiuyuan/InfiniTensor
Fix: fuse Relu to GEMM/Conv
This commit is contained in:
parent
a1f02593d3
commit
df2534d209
|
@ -3,6 +3,7 @@
|
|||
#include "core/runtime.h"
|
||||
#include "ffi/ffi_callback.h"
|
||||
#include "nnet/dbg.h"
|
||||
#include "operators/reshape.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
@ -505,13 +506,18 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
ops.emplace_back(op);
|
||||
continue;
|
||||
}
|
||||
// FIXME: fuse and modify attributes of computing operators
|
||||
if (op->getOpType() == OpType::Relu ||
|
||||
op->getOpType() == OpType::PRelu) {
|
||||
if (auto p = op->getInputs()[0])
|
||||
if (auto sop = p->getSource())
|
||||
if (sop->getOpType() == OpType::Conv ||
|
||||
sop->getOpType() == OpType::Matmul)
|
||||
sop->getOpType() == OpType::Matmul) {
|
||||
visitTime.emplace(op->getGuid(), ++cnt);
|
||||
ops.emplace_back(make_ref<ReshapeObj>(
|
||||
nullptr, op->getInputs()[0], op->getOutputs()[0]));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
vector<Operator> chainOps;
|
||||
visitTime.emplace(op->getGuid(), ++cnt);
|
||||
|
@ -521,6 +527,8 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
while (cur->getPredecessors().size() == 1 &&
|
||||
cur->getPredecessors()[0]->isMemBoundOp()) {
|
||||
cur = cur->getPredecessors()[0];
|
||||
if (visitTime.count(cur->getGuid()))
|
||||
break;
|
||||
tmp.emplace_back(cur);
|
||||
visitTime.emplace(cur->getGuid(), cnt);
|
||||
}
|
||||
|
@ -532,6 +540,8 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
while (cur->getSuccessors().size() == 1 &&
|
||||
cur->getSuccessors()[0]->isMemBoundOp()) {
|
||||
cur = cur->getSuccessors()[0];
|
||||
if (visitTime.count(cur->getGuid()))
|
||||
break;
|
||||
chainOps.emplace_back(cur);
|
||||
visitTime.emplace(cur->getGuid(), cnt);
|
||||
}
|
||||
|
@ -539,7 +549,6 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
|
||||
auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
|
||||
// Eliminate transpose and reshape operators
|
||||
// FIXME: current Relu only support 3D and 4D tensors
|
||||
if (auto eliminatedGraph = mutator->eliminateVertically(
|
||||
make_ref<GraphObj>(runtimeExec, chainOps)))
|
||||
bestGraph = eliminatedGraph;
|
||||
|
@ -550,7 +559,15 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
ops.emplace_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.empty()) {
|
||||
IT_TODO_HALT();
|
||||
IT_ASSERT(graph->getOutputs().size() == 1);
|
||||
IT_ASSERT(graph->getInputs().size() == 1);
|
||||
// auto g = make_ref<GraphObj>(runtime);
|
||||
// TODO: add identity
|
||||
ops.emplace_back(make_ref<ReshapeObj>(nullptr, graph->getInputs()[0],
|
||||
graph->getOutputs()[0]));
|
||||
}
|
||||
return make_ref<GraphObj>(runtimeExec, ops);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue