Fix: fuse Relu to GEMM/Conv

This commit is contained in:
Liyan Zheng 2023-04-30 16:16:16 +08:00
parent a1f02593d3
commit df2534d209
1 changed files with 20 additions and 3 deletions

View File

@ -3,6 +3,7 @@
#include "core/runtime.h"
#include "ffi/ffi_callback.h"
#include "nnet/dbg.h"
#include "operators/reshape.h"
#include <algorithm>
#include <iostream>
@ -505,13 +506,18 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
ops.emplace_back(op);
continue;
}
// FIXME: fuse and modify attributes of computing operators
if (op->getOpType() == OpType::Relu ||
op->getOpType() == OpType::PRelu) {
if (auto p = op->getInputs()[0])
if (auto sop = p->getSource())
if (sop->getOpType() == OpType::Conv ||
sop->getOpType() == OpType::Matmul)
sop->getOpType() == OpType::Matmul) {
visitTime.emplace(op->getGuid(), ++cnt);
ops.emplace_back(make_ref<ReshapeObj>(
nullptr, op->getInputs()[0], op->getOutputs()[0]));
continue;
}
}
vector<Operator> chainOps;
visitTime.emplace(op->getGuid(), ++cnt);
@ -521,6 +527,8 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
while (cur->getPredecessors().size() == 1 &&
cur->getPredecessors()[0]->isMemBoundOp()) {
cur = cur->getPredecessors()[0];
if (visitTime.count(cur->getGuid()))
break;
tmp.emplace_back(cur);
visitTime.emplace(cur->getGuid(), cnt);
}
@ -532,6 +540,8 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
while (cur->getSuccessors().size() == 1 &&
cur->getSuccessors()[0]->isMemBoundOp()) {
cur = cur->getSuccessors()[0];
if (visitTime.count(cur->getGuid()))
break;
chainOps.emplace_back(cur);
visitTime.emplace(cur->getGuid(), cnt);
}
@ -539,7 +549,6 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
// Eliminate transpose and reshape operators
// FIXME: current Relu only support 3D and 4D tensors
if (auto eliminatedGraph = mutator->eliminateVertically(
make_ref<GraphObj>(runtimeExec, chainOps)))
bestGraph = eliminatedGraph;
@ -550,7 +559,15 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
ops.emplace_back(op);
}
}
if (ops.empty()) {
IT_TODO_HALT();
IT_ASSERT(graph->getOutputs().size() == 1);
IT_ASSERT(graph->getInputs().size() == 1);
// auto g = make_ref<GraphObj>(runtime);
// TODO: add identity
ops.emplace_back(make_ref<ReshapeObj>(nullptr, graph->getInputs()[0],
graph->getOutputs()[0]));
}
return make_ref<GraphObj>(runtimeExec, ops);
}