forked from jiuyuan/InfiniTensor
ADD: sub graph replacement. (#56)
reconfig: connections among ops and tensors are now managed by GraphObj. add some comments. merge from master. merge from master. ADD: sub graph replacement. reconfig inputs of op resize, due to the check of operator inputs. ResizeObj::clone. clang format. fix some and add test for multi-output. replacement supports multi-inputs and multi-outputs. add clone for all operators. add replaceSubGraph, addSubGraph. remove extra code. add more test. remove extra print. Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
c8b2c8ed32
commit
43d4798323
|
@ -22,6 +22,27 @@ class GraphObj : public Object {
|
||||||
Tensor cloneTensor(const Tensor &tensor) {
|
Tensor cloneTensor(const Tensor &tensor) {
|
||||||
return addTensor(tensor->clone(runtime));
|
return addTensor(tensor->clone(runtime));
|
||||||
}
|
}
|
||||||
|
void removeOperator(Operator op) {
|
||||||
|
auto it = std::find(ops.begin(), ops.end(), op);
|
||||||
|
if (it != ops.end())
|
||||||
|
ops.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
void removeTensor(Tensor tensor) {
|
||||||
|
auto it = std::find(tensors.begin(), tensors.end(), tensor);
|
||||||
|
if (it != tensors.end())
|
||||||
|
tensors.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deleteConnection(Tensor tensor, Operator op);
|
||||||
|
void addConnection(Tensor tensor, Operator op);
|
||||||
|
void replaceConnection(Tensor oldInput, Tensor newInput, Operator op);
|
||||||
|
|
||||||
|
Operator cloneOperator(Operator op, TensorVec inputs, TensorVec outputs) {
|
||||||
|
auto opClone = op->clone(inputs, outputs);
|
||||||
|
addOperatorAndConnect(opClone);
|
||||||
|
return opClone;
|
||||||
|
}
|
||||||
|
|
||||||
const TensorVec &getTensors() const { return tensors; }
|
const TensorVec &getTensors() const { return tensors; }
|
||||||
const OpVec &getOperators() const { return ops; }
|
const OpVec &getOperators() const { return ops; }
|
||||||
|
@ -79,6 +100,8 @@ class GraphObj : public Object {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool checkValid() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Add reverse connections and Op relationship in ctor.
|
* @brief Add reverse connections and Op relationship in ctor.
|
||||||
|
|
|
@ -0,0 +1,108 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/graph.h"
|
||||||
|
namespace infini {
|
||||||
|
class SubGraphObj : public GraphObj {
|
||||||
|
TensorVec ins; // inputs from outer predecessors, orders are appointed.
|
||||||
|
TensorVec outs; // outputs to outer successors, orders are appointed.
|
||||||
|
|
||||||
|
public:
|
||||||
|
SubGraphObj(Runtime runtime, const TensorVec &inputs);
|
||||||
|
void setOutputs(const TensorVec &tensors) { outs = tensors; }
|
||||||
|
TensorVec getInputsFromOutside() const { return ins; }
|
||||||
|
TensorVec getOutputs2Outside() const { return outs; }
|
||||||
|
bool isInputFromOutside(Tensor t) const {
|
||||||
|
return std::find(ins.begin(), ins.end(), t) != ins.end();
|
||||||
|
}
|
||||||
|
bool isOutput2Outside(Tensor t) const {
|
||||||
|
return std::find(outs.begin(), outs.end(), t) != outs.end();
|
||||||
|
}
|
||||||
|
bool isHead(const Operator &op) const {
|
||||||
|
for (auto in : ins) {
|
||||||
|
auto ops = in->getTargets();
|
||||||
|
if (std::find(ops.begin(), ops.end(), op) != ops.end())
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
bool isTail(const Operator &op) const {
|
||||||
|
for (auto out : outs) {
|
||||||
|
if (op == out->getSource())
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
using SubGraph = Ref<SubGraphObj>;
|
||||||
|
|
||||||
|
// Describe a match for subgraph replacement.
|
||||||
|
class GraphMatchObj {
|
||||||
|
std::unordered_set<Operator> ops;
|
||||||
|
std::unordered_map<Operator, Operator> opMap; // anchor->pattern
|
||||||
|
std::unordered_map<Operator, Operator> opMapRevese; // pattern->anchor
|
||||||
|
std::unordered_map<Tensor, Tensor> tensorMap; // pattern->anchor
|
||||||
|
SubGraph pattern;
|
||||||
|
|
||||||
|
public:
|
||||||
|
GraphMatchObj(SubGraph pattern) : pattern(pattern) {}
|
||||||
|
Ref<GraphMatchObj> clone();
|
||||||
|
void addOp(const Operator &anchorOp, const Operator &patternOp);
|
||||||
|
bool hasContained(const Operator &op) const { return opMap.count(op) > 0; }
|
||||||
|
bool hasMatched(const Operator &op) const {
|
||||||
|
return opMapRevese.count(op) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor getAnchorByPattern(const Tensor &t) {
|
||||||
|
IT_ASSERT(tensorMap.count(t) > 0);
|
||||||
|
return tensorMap.at(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator getAnchorByPattern(const Operator &op) {
|
||||||
|
IT_ASSERT(opMapRevese.count(op) > 0);
|
||||||
|
return opMapRevese.at(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorVec getInputs() const;
|
||||||
|
TensorVec getOutputs() const;
|
||||||
|
std::unordered_set<Operator> getOps() const { return ops; }
|
||||||
|
std::string toString() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void recordOutsideTensorMap(const Operator &patternOp,
|
||||||
|
const Operator &anchorOp);
|
||||||
|
};
|
||||||
|
using MatchGraph = Ref<GraphMatchObj>;
|
||||||
|
|
||||||
|
class SubGraphRewriter {
|
||||||
|
SubGraph pattern;
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
public:
|
||||||
|
SubGraphRewriter(Graph g) : graph(g) {}
|
||||||
|
vector<MatchGraph> findMatch(const SubGraph &pattern);
|
||||||
|
void replaceSubGraph(const SubGraph &pattern, const SubGraph &replacement);
|
||||||
|
TensorVec addSubGraph(const SubGraph &pattern, const TensorVec &inputs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void removeSubGraph(MatchGraph match);
|
||||||
|
bool MatchNode(const Operator &a, const Operator &b, bool isHead,
|
||||||
|
bool isTail) const;
|
||||||
|
OpLists matchInCandidates(const OpVec &ops, const Operator &opDst,
|
||||||
|
bool isHead, bool isTail);
|
||||||
|
bool findMatch(const MatchGraph &lastMatched, const Operator &opLastMatched,
|
||||||
|
const Operator &opDst, vector<MatchGraph> &matched);
|
||||||
|
bool findMatch2(const MatchGraph &lastMatched,
|
||||||
|
const Operator &opLastMatched, const Operator &opDst,
|
||||||
|
vector<MatchGraph> &matched);
|
||||||
|
void updateMatchedGraph(const MatchGraph &lastMatched, OpLists &opMatched,
|
||||||
|
vector<MatchGraph> &gMatched, Operator dst);
|
||||||
|
|
||||||
|
bool checkReplacement(const SubGraph &pattern, const SubGraph &other) const;
|
||||||
|
bool checkReplacement(const TensorVec &left, const TensorVec &right) const;
|
||||||
|
bool isReplacable(const Tensor &l, const Tensor &r) const;
|
||||||
|
bool checkOverlapsWithPreviousMatch(
|
||||||
|
const MatchGraph &match,
|
||||||
|
const std::unordered_set<Operator> &nodesToDelete) const;
|
||||||
|
bool checkMatchValid(const MatchGraph &match) const;
|
||||||
|
};
|
||||||
|
}; // namespace infini
|
|
@ -1,6 +1,5 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "core/tensor.h"
|
#include "core/tensor.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
enum class OpType {
|
enum class OpType {
|
||||||
|
@ -139,7 +138,10 @@ struct OpPerfKey {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class GraphObj;
|
||||||
class OperatorObj : public Object {
|
class OperatorObj : public Object {
|
||||||
|
friend class GraphObj;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
OpType type;
|
OpType type;
|
||||||
TensorVec inputs;
|
TensorVec inputs;
|
||||||
|
@ -189,8 +191,6 @@ class OperatorObj : public Object {
|
||||||
IT_ASSERT(i < outputs.size(), "Index exceeded");
|
IT_ASSERT(i < outputs.size(), "Index exceeded");
|
||||||
return outputs.at(i);
|
return outputs.at(i);
|
||||||
}
|
}
|
||||||
void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
|
|
||||||
void addSuccessors(const Operator &op) { successors.emplace_back(op); }
|
|
||||||
OpVec getPredecessors() const { return wrefs_to_refs(predecessors); }
|
OpVec getPredecessors() const { return wrefs_to_refs(predecessors); }
|
||||||
OpVec getSuccessors() const { return wrefs_to_refs(successors); }
|
OpVec getSuccessors() const { return wrefs_to_refs(successors); }
|
||||||
OpType getOpType() const { return type; }
|
OpType getOpType() const { return type; }
|
||||||
|
@ -225,6 +225,12 @@ class OperatorObj : public Object {
|
||||||
* and output shapes.
|
* and output shapes.
|
||||||
*/
|
*/
|
||||||
virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
|
virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
|
||||||
|
|
||||||
|
void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
|
||||||
|
void addSuccessors(const Operator &op) { successors.emplace_back(op); }
|
||||||
|
void removePredecessors(const Operator &op);
|
||||||
|
void removeSuccessors(const Operator &op);
|
||||||
|
void replaceInput(Tensor t1, Tensor t2);
|
||||||
};
|
};
|
||||||
|
|
||||||
#define OP_CLONE(OpObj) \
|
#define OP_CLONE(OpObj) \
|
||||||
|
|
|
@ -25,6 +25,7 @@ enum class OpType;
|
||||||
|
|
||||||
using TensorVec = vector<Tensor>;
|
using TensorVec = vector<Tensor>;
|
||||||
using OpVec = vector<Operator>;
|
using OpVec = vector<Operator>;
|
||||||
|
using OpLists = list<Operator>;
|
||||||
|
|
||||||
using VType = uint32_t;
|
using VType = uint32_t;
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,11 @@
|
||||||
#include "core/data_type.h"
|
#include "core/data_type.h"
|
||||||
#include "core/object.h"
|
#include "core/object.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
class GraphObj;
|
||||||
class TensorBaseObj : public Object {
|
class TensorBaseObj : public Object {
|
||||||
|
friend class GraphObj;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// enum TensorType {
|
// enum TensorType {
|
||||||
// Input,
|
// Input,
|
||||||
|
@ -44,13 +46,24 @@ class TensorBaseObj : public Object {
|
||||||
DataType getDType() const { return dtype; }
|
DataType getDType() const { return dtype; }
|
||||||
Runtime getRuntime() const { return runtime; }
|
Runtime getRuntime() const { return runtime; }
|
||||||
|
|
||||||
void addTarget(const Operator &op) { targets.emplace_back(op); }
|
// std::pair<Operator *, int> getOutputOfWithIndex();
|
||||||
void setSource(const Operator &op) { source = op; }
|
|
||||||
|
|
||||||
bool hasTarget() const { return !targets.empty(); }
|
bool hasTarget() const { return !targets.empty(); }
|
||||||
|
|
||||||
OpVec getTargets() const { return wrefs_to_refs(targets); }
|
OpVec getTargets() const { return wrefs_to_refs(targets); }
|
||||||
Operator getSource() const { return source.lock(); }
|
Operator getSource() const { return source.lock(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void addTarget(const Operator &op) { targets.emplace_back(op); }
|
||||||
|
void setSource(const Operator &op) { source = op; }
|
||||||
|
void removeTarget(const Operator &op) {
|
||||||
|
for (auto itr = targets.begin(); itr != targets.end();) {
|
||||||
|
if (itr->lock() == op)
|
||||||
|
itr = targets.erase(itr);
|
||||||
|
else
|
||||||
|
++itr;
|
||||||
|
}
|
||||||
|
}
|
||||||
// std::pair<Operator *, int> getSourceWithIndex();
|
// std::pair<Operator *, int> getSourceWithIndex();
|
||||||
|
|
||||||
// bool setScalar(VType val) {
|
// bool setScalar(VType val) {
|
||||||
|
|
|
@ -35,7 +35,6 @@ class BatchNormObj : public OperatorObj {
|
||||||
float eps = 1e-5, bool training = false);
|
float eps = 1e-5, bool training = false);
|
||||||
OP_CLONE(BatchNormObj);
|
OP_CLONE(BatchNormObj);
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
std::string toString() const override;
|
std::string toString() const override;
|
||||||
|
|
||||||
// output size will be 3 when training
|
// output size will be 3 when training
|
||||||
|
|
|
@ -48,4 +48,4 @@ class SplitObj : public OperatorObj {
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
vector<int> getOpAttrVector() const override;
|
vector<int> getOpAttrVector() const override;
|
||||||
};
|
};
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -143,6 +143,71 @@ OpVec GraphObj::getComputeOps() const {
|
||||||
if (op->isComputeOp())
|
if (op->isComputeOp())
|
||||||
opList.emplace_back(op);
|
opList.emplace_back(op);
|
||||||
return opList;
|
return opList;
|
||||||
};
|
}
|
||||||
|
|
||||||
|
void GraphObj::deleteConnection(Tensor tensor, Operator op) {
|
||||||
|
// if op is target
|
||||||
|
IT_ASSERT(std::find(tensor->getTargets().begin(),
|
||||||
|
tensor->getTargets().end(),
|
||||||
|
op) != tensor->getTargets().end());
|
||||||
|
tensor->removeTarget(op);
|
||||||
|
if (tensor->getSource()) {
|
||||||
|
tensor->getSource()->removeSuccessors(op);
|
||||||
|
op->removePredecessors(tensor->getSource());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add op as a target
|
||||||
|
void GraphObj::addConnection(Tensor tensor, Operator op) {
|
||||||
|
tensor->addTarget(op);
|
||||||
|
if (tensor->getSource()) {
|
||||||
|
tensor->getSource()->addSuccessors(op);
|
||||||
|
op->addPredecessors(tensor->getSource());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GraphObj::replaceConnection(Tensor oldTensor, Tensor newTensor,
|
||||||
|
Operator op) {
|
||||||
|
// op is a target of old tensor
|
||||||
|
IT_ASSERT(std::find(oldTensor->getTargets().begin(),
|
||||||
|
oldTensor->getTargets().end(),
|
||||||
|
op) != oldTensor->getTargets().end());
|
||||||
|
addConnection(newTensor, op);
|
||||||
|
deleteConnection(oldTensor, op);
|
||||||
|
op->replaceInput(oldTensor, newTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor's "source" and "target" must be in "ops".
|
||||||
|
// tensor has no "source" and no "target" must not exist.
|
||||||
|
// "inputs" or "outputs" of operators must be in "tensors"
|
||||||
|
// "predecessors" and "successors" of an operator of "ops" must be in "ops".
|
||||||
|
bool GraphObj::checkValid() const {
|
||||||
|
for (auto tensor : tensors) {
|
||||||
|
IT_ASSERT(!(tensor->getTargets().size() == 0 &&
|
||||||
|
nullptr == tensor->getSource()));
|
||||||
|
for (auto op : tensor->getTargets()) {
|
||||||
|
IT_ASSERT(std::find(ops.begin(), ops.end(), op) != ops.end());
|
||||||
|
}
|
||||||
|
auto op = tensor->getSource();
|
||||||
|
IT_ASSERT(!(op && std::find(ops.begin(), ops.end(), op) == ops.end()));
|
||||||
|
}
|
||||||
|
for (auto op : ops) {
|
||||||
|
for (auto tensor : op->getInputs()) {
|
||||||
|
IT_ASSERT(std::find(tensors.begin(), tensors.end(), tensor) !=
|
||||||
|
tensors.end());
|
||||||
|
}
|
||||||
|
for (auto tensor : op->getOutputs()) {
|
||||||
|
IT_ASSERT(std::find(tensors.begin(), tensors.end(), tensor) !=
|
||||||
|
tensors.end());
|
||||||
|
}
|
||||||
|
for (auto pre : op->getPredecessors()) {
|
||||||
|
IT_ASSERT(std::find(ops.begin(), ops.end(), pre) != ops.end());
|
||||||
|
}
|
||||||
|
for (auto suc : op->getSuccessors()) {
|
||||||
|
IT_ASSERT(std::find(ops.begin(), ops.end(), suc) != ops.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -0,0 +1,465 @@
|
||||||
|
#include "core/graph_match.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
Ref<GraphMatchObj> GraphMatchObj::clone() {
|
||||||
|
auto newObj = make_ref<GraphMatchObj>(pattern);
|
||||||
|
newObj->ops = ops;
|
||||||
|
newObj->opMap = opMap;
|
||||||
|
newObj->opMapRevese = opMapRevese;
|
||||||
|
newObj->tensorMap = tensorMap;
|
||||||
|
return newObj;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GraphMatchObj::addOp(const Operator &anchorOp, const Operator &patternOp) {
|
||||||
|
ops.emplace(anchorOp);
|
||||||
|
opMap.emplace(anchorOp, patternOp);
|
||||||
|
opMapRevese.emplace(patternOp, anchorOp);
|
||||||
|
recordOutsideTensorMap(patternOp, anchorOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorVec GraphMatchObj::getInputs() const {
|
||||||
|
TensorVec ret;
|
||||||
|
for (auto t : pattern->getInputsFromOutside()) {
|
||||||
|
IT_ASSERT(tensorMap.count(t) > 0);
|
||||||
|
ret.push_back(tensorMap.at(t));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorVec GraphMatchObj::getOutputs() const {
|
||||||
|
TensorVec ret;
|
||||||
|
for (auto t : pattern->getOutputs2Outside()) {
|
||||||
|
IT_ASSERT(tensorMap.count(t) > 0);
|
||||||
|
ret.push_back(tensorMap.at(t));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GraphMatchObj::toString() const {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "MatchGraph operators:\n";
|
||||||
|
for (const auto &op : ops) {
|
||||||
|
vector<UidBaseType> preds, succs;
|
||||||
|
for (auto &o : op->getPredecessors())
|
||||||
|
preds.emplace_back(o->getGuid());
|
||||||
|
for (auto &o : op->getSuccessors())
|
||||||
|
succs.emplace_back(o->getGuid());
|
||||||
|
oss << "OP " << op->getGuid();
|
||||||
|
oss << ", pred " << vecToString(preds);
|
||||||
|
oss << ", succ " << vecToString(succs);
|
||||||
|
oss << ", " << op << "\n";
|
||||||
|
}
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the input pattern tensor is from outside,find the
|
||||||
|
// corresponding input anchor tensor,and record.
|
||||||
|
void GraphMatchObj::recordOutsideTensorMap(const Operator &patternOp,
|
||||||
|
const Operator &anchorOp) {
|
||||||
|
for (size_t i = 0; i < patternOp->getInputs().size(); ++i) {
|
||||||
|
if (pattern->isInputFromOutside(patternOp->getInputs(i)))
|
||||||
|
tensorMap.emplace(patternOp->getInputs(i), anchorOp->getInputs(i));
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < patternOp->getOutputs().size(); ++i) {
|
||||||
|
if (pattern->isOutput2Outside(patternOp->getOutput(i)))
|
||||||
|
tensorMap.emplace(patternOp->getOutput(i), anchorOp->getOutput(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SubGraphObj::SubGraphObj(Runtime runtime, const TensorVec &inputs)
|
||||||
|
: GraphObj(runtime), ins(inputs) {
|
||||||
|
for (auto t : ins)
|
||||||
|
tensors.push_back(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<MatchGraph> SubGraphRewriter::findMatch(const SubGraph &pattern) {
|
||||||
|
this->pattern = pattern;
|
||||||
|
vector<MatchGraph> matches;
|
||||||
|
bool firstHead = true, retStatus = true;
|
||||||
|
for (auto input : pattern->getInputsFromOutside()) {
|
||||||
|
auto inputOf = input->getTargets();
|
||||||
|
for (auto opHead : inputOf) {
|
||||||
|
if (std::find(pattern->getOperators().begin(),
|
||||||
|
pattern->getOperators().end(),
|
||||||
|
opHead) == pattern->getOperators().end())
|
||||||
|
continue; // not belongs to pattern
|
||||||
|
if (opHead->getPredecessors().size() > 0) // not a head
|
||||||
|
continue;
|
||||||
|
if (firstHead) {
|
||||||
|
firstHead = false;
|
||||||
|
if (!findMatch(nullptr, nullptr, opHead, matches)) {
|
||||||
|
retStatus = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!findMatch2(nullptr, nullptr, opHead, matches)) {
|
||||||
|
retStatus = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!retStatus)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<MatchGraph> ret;
|
||||||
|
for (auto match : matches) {
|
||||||
|
if (checkMatchValid(match))
|
||||||
|
ret.push_back(match);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::findMatch(const MatchGraph &gLastMatch,
|
||||||
|
const Operator &opLastMatch,
|
||||||
|
const Operator &opPattern,
|
||||||
|
vector<MatchGraph> &gMatch) {
|
||||||
|
OpVec candidates =
|
||||||
|
opLastMatch ? opLastMatch->getSuccessors() : graph->getOperators();
|
||||||
|
OpLists nodesMatch =
|
||||||
|
matchInCandidates(candidates, opPattern, pattern->isHead(opPattern),
|
||||||
|
pattern->isTail(opPattern));
|
||||||
|
|
||||||
|
IT_ASSERT(nodesMatch.size() <= 1 || !opLastMatch);
|
||||||
|
updateMatchedGraph(gLastMatch, nodesMatch, gMatch, opPattern);
|
||||||
|
|
||||||
|
if (nodesMatch.size() == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// dst is matched, process successors recursively
|
||||||
|
for (auto successorPattern : opPattern->getSuccessors()) {
|
||||||
|
|
||||||
|
bool bRet = false;
|
||||||
|
if (opLastMatch) {
|
||||||
|
IT_ASSERT(nodesMatch.size() == 1);
|
||||||
|
if (gLastMatch->hasMatched(successorPattern))
|
||||||
|
continue;
|
||||||
|
bRet = findMatch(gLastMatch, nodesMatch.front(), successorPattern,
|
||||||
|
gMatch);
|
||||||
|
} else {
|
||||||
|
IT_ASSERT(nodesMatch.size() == gMatch.size());
|
||||||
|
auto tmp1 = gMatch;
|
||||||
|
auto itr1 = nodesMatch.begin();
|
||||||
|
auto itr2 = gMatch.begin();
|
||||||
|
for (; itr1 != nodesMatch.end() && itr2 != gMatch.end(); ++itr2) {
|
||||||
|
if (findMatch(*itr2, *itr1, successorPattern, tmp1)) {
|
||||||
|
bRet = true;
|
||||||
|
++itr1;
|
||||||
|
} else
|
||||||
|
itr1 = nodesMatch.erase(itr1);
|
||||||
|
}
|
||||||
|
gMatch = tmp1;
|
||||||
|
}
|
||||||
|
// not found,return false
|
||||||
|
if (!bRet) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::findMatch2(const MatchGraph &gLastMatch,
|
||||||
|
const Operator &opLastMatch,
|
||||||
|
const Operator &opPattern,
|
||||||
|
vector<MatchGraph> &matches) {
|
||||||
|
vector<MatchGraph> curMatches;
|
||||||
|
for (auto match : matches) {
|
||||||
|
OpVec candidates =
|
||||||
|
opLastMatch ? opLastMatch->getSuccessors() : graph->getOperators();
|
||||||
|
// filter candiates in matches
|
||||||
|
for (auto itr2 = candidates.begin(); itr2 != candidates.end();) {
|
||||||
|
if (match->hasContained(
|
||||||
|
*itr2)) // already belonged to the matched sub graph
|
||||||
|
itr2 = candidates.erase(itr2);
|
||||||
|
else
|
||||||
|
++itr2;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpLists nodesMatch = matchInCandidates(
|
||||||
|
candidates, opPattern, opPattern->getPredecessors().size() == 0,
|
||||||
|
opPattern->getSuccessors().size() == 0);
|
||||||
|
|
||||||
|
// no match nodes found, do not add the match to curMatches, continue
|
||||||
|
if (nodesMatch.size() == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto node : nodesMatch) {
|
||||||
|
auto curMatch = match->clone();
|
||||||
|
curMatch->addOp(node, opPattern); // anchor and pattern
|
||||||
|
|
||||||
|
// add to curMatches
|
||||||
|
curMatches.push_back(curMatch);
|
||||||
|
|
||||||
|
// dst is matched, process successors recursively
|
||||||
|
for (auto successorPattern : opPattern->getSuccessors()) {
|
||||||
|
if (match->hasMatched(successorPattern)) // has already matched
|
||||||
|
continue;
|
||||||
|
if (!findMatch(curMatch, node, successorPattern, curMatches)) {
|
||||||
|
// curMatch has been removed from curMatches in
|
||||||
|
// "findMatch",so just break
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
matches = curMatches;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpLists SubGraphRewriter::matchInCandidates(const OpVec &ops,
|
||||||
|
const Operator &opPattern,
|
||||||
|
bool isHead, bool isTail) {
|
||||||
|
OpLists ret;
|
||||||
|
for (auto op : ops) {
|
||||||
|
if (MatchNode(opPattern, op, isHead, isTail))
|
||||||
|
ret.push_back(op);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::MatchNode(const Operator &a, const Operator &b,
|
||||||
|
bool isHead, bool isTail) const {
|
||||||
|
if (a->getOpType() != b->getOpType())
|
||||||
|
return false;
|
||||||
|
if (a->hash() != b->hash())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!isHead)
|
||||||
|
if (a->getPredecessors().size() != b->getPredecessors().size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!isTail)
|
||||||
|
if (a->getSuccessors().size() != b->getSuccessors().size())
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
void SubGraphRewriter::updateMatchedGraph(const MatchGraph &gLastMatch,
|
||||||
|
OpLists &opMatch,
|
||||||
|
vector<MatchGraph> &gMatch,
|
||||||
|
Operator opPattern) {
|
||||||
|
if (opMatch.size() == 0) {
|
||||||
|
if (nullptr != gLastMatch) {
|
||||||
|
auto pos = std::find(gMatch.begin(), gMatch.end(), gLastMatch);
|
||||||
|
IT_ASSERT(pos != gMatch.end());
|
||||||
|
gMatch.erase(pos);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// anchor is a head
|
||||||
|
if (nullptr == gLastMatch) {
|
||||||
|
for (auto op : opMatch) {
|
||||||
|
auto match = make_ref<GraphMatchObj>(pattern);
|
||||||
|
match->addOp(op, opPattern);
|
||||||
|
gMatch.push_back(match);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
IT_ASSERT(opMatch.size() == 1);
|
||||||
|
gLastMatch->addOp(opMatch.front(), opPattern);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::checkOverlapsWithPreviousMatch(
|
||||||
|
const MatchGraph &match,
|
||||||
|
const std::unordered_set<Operator> &nodesToDelete) const {
|
||||||
|
for (auto op : match->getOps()) {
|
||||||
|
if (nodesToDelete.count(op) > 0)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::checkMatchValid(const MatchGraph &match) const {
|
||||||
|
for (auto t : pattern->getInputsFromOutside()) {
|
||||||
|
auto tAnchor = match->getAnchorByPattern(t);
|
||||||
|
// the corrresponding precessor must not belong to the match
|
||||||
|
auto preOpAnchor = tAnchor->getSource();
|
||||||
|
if (preOpAnchor && match->hasContained(preOpAnchor)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check connections
|
||||||
|
for (auto opPattern : pattern->getOperators()) {
|
||||||
|
auto opAnchor = match->getAnchorByPattern(opPattern);
|
||||||
|
for (auto prePattern : opPattern->getPredecessors()) {
|
||||||
|
auto preAnchor = match->getAnchorByPattern(prePattern);
|
||||||
|
auto ops = opAnchor->getPredecessors();
|
||||||
|
if (std::find(ops.begin(), ops.end(), preAnchor) == ops.end())
|
||||||
|
return false;
|
||||||
|
ops = preAnchor->getSuccessors();
|
||||||
|
if (std::find(ops.begin(), ops.end(), opAnchor) == ops.end())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// replace all sub graphs which matched subA with subB in g
|
||||||
|
void SubGraphRewriter::replaceSubGraph(const SubGraph &pattern,
|
||||||
|
const SubGraph &replacement) {
|
||||||
|
IT_ASSERT(checkReplacement(pattern, replacement));
|
||||||
|
this->pattern = pattern;
|
||||||
|
|
||||||
|
// find matches in graph.
|
||||||
|
auto matches = findMatch(pattern);
|
||||||
|
|
||||||
|
std::unordered_set<Operator> nodesToDelete;
|
||||||
|
map<Tensor, Tensor> replaceMap;
|
||||||
|
map<Tensor, Tensor> replaceMapReverse;
|
||||||
|
for (auto match : matches) {
|
||||||
|
// matches may overlap with eachother. if some operator has been in
|
||||||
|
// another folded match,we must skip this one
|
||||||
|
if (!checkOverlapsWithPreviousMatch(match, nodesToDelete))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto inputs = match->getInputs();
|
||||||
|
for (auto &input : inputs) {
|
||||||
|
if (replaceMap.count(input) > 0)
|
||||||
|
input = replaceMap[input];
|
||||||
|
}
|
||||||
|
auto outputs = match->getOutputs();
|
||||||
|
|
||||||
|
// first, remove old successors for input
|
||||||
|
for (auto input : inputs) {
|
||||||
|
for (auto op : input->getTargets()) {
|
||||||
|
if (match->hasContained(op)) {
|
||||||
|
graph->deleteConnection(input, op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// second, insert replacement sub graph to graph.
|
||||||
|
auto newOutputs = addSubGraph(replacement, inputs);
|
||||||
|
|
||||||
|
// check replaced outputs and record
|
||||||
|
IT_ASSERT(outputs.size() == newOutputs.size());
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
IT_ASSERT(isReplacable(outputs[i], newOutputs[i]));
|
||||||
|
replaceMap.emplace(outputs[i], newOutputs[i]);
|
||||||
|
replaceMapReverse.emplace(newOutputs[i], outputs[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// third, change connections for new output
|
||||||
|
for (auto output : outputs) {
|
||||||
|
auto successors = output->getTargets();
|
||||||
|
for (auto successor : successors) {
|
||||||
|
auto newOutput = replaceMap[output];
|
||||||
|
graph->replaceConnection(output, newOutput, successor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record ops need to delete
|
||||||
|
for (auto op : match->getOps())
|
||||||
|
nodesToDelete.insert(op);
|
||||||
|
|
||||||
|
// remove match from graph
|
||||||
|
for (auto op : match->getOps()) {
|
||||||
|
for (auto tensor : op->getInputs()) {
|
||||||
|
if (replaceMapReverse.count(tensor) > 0)
|
||||||
|
tensor = replaceMapReverse[tensor];
|
||||||
|
if (std::find(inputs.begin(), inputs.end(), tensor) ==
|
||||||
|
inputs.end()) {
|
||||||
|
graph->removeTensor(tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto tensor : op->getOutputs()) {
|
||||||
|
graph->removeTensor(tensor);
|
||||||
|
}
|
||||||
|
graph->removeOperator(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
IT_ASSERT(graph->checkValid());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// "inputs" must be tensors in original graph
|
||||||
|
TensorVec SubGraphRewriter::addSubGraph(const SubGraph &g,
|
||||||
|
const TensorVec &inputs) {
|
||||||
|
// check inputs
|
||||||
|
for (auto input : inputs) {
|
||||||
|
auto tensors = graph->getTensors();
|
||||||
|
IT_ASSERT(std::find(tensors.begin(), tensors.end(), input) !=
|
||||||
|
tensors.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// check compatible with sub graph
|
||||||
|
auto ins = g->getInputsFromOutside();
|
||||||
|
IT_ASSERT(checkReplacement(ins, inputs));
|
||||||
|
|
||||||
|
std::map<Tensor, Tensor> tensorMap;
|
||||||
|
for (size_t i = 0; i < ins.size(); ++i) {
|
||||||
|
tensorMap.emplace(ins[i], inputs[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto t : g->getTensors()) {
|
||||||
|
if (tensorMap.find(t) == tensorMap.end()) {
|
||||||
|
auto tClone = graph->addTensor(t->getDims(), t->getDType());
|
||||||
|
tensorMap.emplace(t, tClone);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto op : g->getOperators()) {
|
||||||
|
TensorVec inputs, outputs;
|
||||||
|
for (auto t : op->getInputs()) {
|
||||||
|
inputs.push_back(tensorMap.at(t));
|
||||||
|
}
|
||||||
|
for (auto t : op->getOutputs()) {
|
||||||
|
outputs.push_back(tensorMap.at(t));
|
||||||
|
}
|
||||||
|
graph->cloneOperator(op, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorVec out;
|
||||||
|
for (auto t : g->getOutputs2Outside()) {
|
||||||
|
out.push_back(tensorMap[t]);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SubGraphRewriter::removeSubGraph(MatchGraph match) {
|
||||||
|
TensorVec inputs = match->getInputs();
|
||||||
|
|
||||||
|
for (auto op : match->getOps()) {
|
||||||
|
for (auto tensor : op->getInputs()) {
|
||||||
|
if (std::find(inputs.begin(), inputs.end(), tensor) ==
|
||||||
|
inputs.end()) {
|
||||||
|
graph->removeTensor(tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto tensor : op->getOutputs()) {
|
||||||
|
graph->removeTensor(tensor);
|
||||||
|
}
|
||||||
|
graph->removeOperator(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputs and outputs must be appointed.
|
||||||
|
bool SubGraphRewriter::checkReplacement(const SubGraph &pattern,
|
||||||
|
const SubGraph &other) const {
|
||||||
|
return checkReplacement(pattern->getInputsFromOutside(),
|
||||||
|
other->getInputsFromOutside()) &&
|
||||||
|
checkReplacement(pattern->getOutputs2Outside(),
|
||||||
|
other->getOutputs2Outside()) &&
|
||||||
|
pattern->getInputsFromOutside().size() != 0 &&
|
||||||
|
pattern->getOutputs2Outside().size() != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::checkReplacement(const TensorVec &left,
|
||||||
|
const TensorVec &right) const {
|
||||||
|
if (left.size() != right.size())
|
||||||
|
return false;
|
||||||
|
for (size_t i = 0; i < left.size(); ++i) {
|
||||||
|
if (!isReplacable(left[i], right[i]))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SubGraphRewriter::isReplacable(const Tensor &l, const Tensor &r) const {
|
||||||
|
return (l->getDType() == r->getDType() && l->getDims() == r->getDims());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -37,6 +37,32 @@ bool OperatorObj::isMemBoundOp() const {
|
||||||
type == OpType::Transpose;
|
type == OpType::Transpose;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Erases every entry of `predecessors` whose lock() yields `op`.
void OperatorObj::removePredecessors(const Operator &op) {
    auto cursor = predecessors.begin();
    while (cursor != predecessors.end()) {
        const bool refersToOp = (cursor->lock() == op);
        if (refersToOp) {
            // erase() hands back the iterator following the removed entry.
            cursor = predecessors.erase(cursor);
        } else {
            ++cursor;
        }
    }
}
|
||||||
|
|
||||||
|
// Erases every entry of `successors` whose lock() yields `op`.
void OperatorObj::removeSuccessors(const Operator &op) {
    auto cursor = successors.begin();
    while (cursor != successors.end()) {
        const bool refersToOp = (cursor->lock() == op);
        if (refersToOp) {
            // erase() hands back the iterator following the removed entry.
            cursor = successors.erase(cursor);
        } else {
            ++cursor;
        }
    }
}
|
||||||
|
|
||||||
|
// Substitutes `t2` for every occurrence of `t1` in this operator's input
// list. Only the `inputs` container is touched here.
void OperatorObj::replaceInput(Tensor t1, Tensor t2) {
    for (auto &slot : inputs) {
        if (slot == t1)
            slot = t2;
    }
}
|
||||||
|
|
||||||
OpPerfKey OperatorObj::getOpPerfKey() const {
|
OpPerfKey OperatorObj::getOpPerfKey() const {
|
||||||
auto workloadVector = getWorkloadVector();
|
auto workloadVector = getWorkloadVector();
|
||||||
// Calculate hash of workload, i.e. hash with shape. This is different from
|
// Calculate hash of workload, i.e. hash with shape. This is different from
|
||||||
|
|
|
@ -41,7 +41,7 @@ vector<int> ConvBaseObj::getWorkloadVector() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> ConvBaseObj::getOpAttrVector() const {
|
vector<int> ConvBaseObj::getOpAttrVector() const {
|
||||||
IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
|
// IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
|
||||||
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
|
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,6 @@ vector<int> PoolingObj::getWorkloadVector() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<int> PoolingObj::getOpAttrVector() const {
|
vector<int> PoolingObj::getOpAttrVector() const {
|
||||||
IT_TODO_HALT();
|
|
||||||
return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw};
|
return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,422 @@
|
||||||
|
#include "core/blob.h"
|
||||||
|
#include "core/graph_match.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/concat.h"
|
||||||
|
#include "operators/conv.h"
|
||||||
|
#include "operators/element_wise.h"
|
||||||
|
#include "operators/extend.h"
|
||||||
|
#include "operators/pad.h"
|
||||||
|
#include "operators/pooling.h"
|
||||||
|
#include "operators/reduce_mean.h"
|
||||||
|
#include "operators/slice.h"
|
||||||
|
#include "operators/split.h"
|
||||||
|
#include "operators/unary.h"
|
||||||
|
#include "test.h"
|
||||||
|
namespace infini {
|
||||||
|
// hrnet48 head match conv-relu
|
||||||
|
// Builds a host graph made of conv/relu operators, then searches it for a
// conv -> relu pattern and expects exactly two matches.
TEST(SubGraphRewriter, subGraphMatch1) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    // Host graph.
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 3, 244, 244}, DataType::UInt32);
    Tensor w0 = g->addTensor({64, 3, 3, 3}, DataType::UInt32);
    auto conv = g->addOp<ConvObj>(i0, w0, nullptr);
    auto relu = g->addOp<ReluObj>(conv->getOutput(), nullptr);

    auto w1 = g->addTensor({64, 64, 3, 3}, DataType::UInt32);
    auto conv1 = g->addOp<ConvObj>(relu->getOutput(0), w1, nullptr);
    auto relu1 = g->addOp<ReluObj>(conv1->getOutput(), nullptr);

    auto w2 = g->addTensor({64, 64, 1, 1}, DataType::UInt32);
    auto conv2 = g->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
    auto relu2 = g->addOp<ReluObj>(conv2->getOutput(), nullptr);

    // Second consumer of relu1's output; this conv has no relu after it.
    auto w3 = g->addTensor({256, 64, 1, 1}, DataType::UInt32);
    g->addOp<ConvObj>(relu1->getOutput(0), w3, nullptr);

    auto w4 = g->addTensor({64, 64, 3, 3}, DataType::UInt32);
    auto conv4 = g->addOp<ConvObj>(relu2->getOutput(0), w4, nullptr);
    g->addOp<ReluObj>(conv4->getOutput(), nullptr);

    // Pattern: conv -> relu on a stand-alone input tensor.
    Tensor si0 =
        make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
    SubGraph subG = make_ref<SubGraphObj>(runtime, TensorVec{si0});
    Tensor sw0 = subG->addTensor({64, 64, 3, 3}, DataType::UInt32);
    auto sconv0 = subG->addOp<ConvObj>(si0, sw0, nullptr);
    auto srelu0 = subG->addOp<ReluObj>(sconv0->getOutput(), nullptr);
    subG->setOutputs(srelu0->getOutputs());

    SubGraphRewriter v(g);
    vector<MatchGraph> subgs = v.findMatch(subG);

    // EXPECT_EQ reports the actual count on failure, unlike EXPECT_TRUE.
    EXPECT_EQ(subgs.size(), 2u);
}
|
||||||
|
|
||||||
|
// Builds a host graph out of repeated single-input residual sub graphs,
// checks the operator count and the number of matches for each pattern, then
// replaces one pattern with a shorter sub graph and re-checks the operator
// count.
TEST(MatchGraph, single_input) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    // Fills `sg` with: relu -> conv -> relu -> conv -> add(conv, first relu).
    // Both convolution weights have shape {channels, channels, 3, 3}.
    auto buildResidual = [](const SubGraph &sg, Tensor in, int channels) {
        auto reluIn = sg->addOp<ReluObj>(in, nullptr);
        auto wA = sg->addTensor({channels, channels, 3, 3}, DataType::UInt32);
        auto convA = sg->addOp<ConvObj>(reluIn->getOutput(0), wA, nullptr);
        auto reluMid = sg->addOp<ReluObj>(convA->getOutput(), nullptr);
        auto wB = sg->addTensor({channels, channels, 3, 3}, DataType::UInt32);
        auto convB = sg->addOp<ConvObj>(reluMid->getOutput(0), wB, nullptr);
        auto sum = sg->addOp<AddObj>(convB->getOutput(0),
                                     reluIn->getOutput(0), nullptr);
        sg->setOutputs({sum->getOutput()});
    };

    // subG0
    Tensor si0 =
        make_ref<TensorObj>(Shape{1, 96, 28, 28}, DataType::UInt32, runtime);
    SubGraph subG = make_ref<SubGraphObj>(runtime, TensorVec{si0});
    buildResidual(subG, si0, 96);

    // subG1
    Tensor si00 =
        make_ref<TensorObj>(Shape{1, 48, 56, 56}, DataType::UInt32, runtime);
    SubGraph subG1 = make_ref<SubGraphObj>(runtime, TensorVec{si00});
    buildResidual(subG1, si00, 48);

    // graph
    Graph g = make_ref<GraphObj>(runtime);
    SubGraphRewriter v(g);

    Tensor i0 = g->addTensor({1, 256, 56, 56}, DataType::UInt32);
    auto relu0 = g->addOp<ReluObj>(i0, nullptr);

    Tensor w0 = g->addTensor({96, 256, 3, 3}, DataType::UInt32);
    auto conv0 =
        g->addOp<ConvObj>(relu0->getOutput(0), w0, nullptr, 1, 1, 2, 2);

    // First branch: four chained copies of subG, each fed by the previous
    // copy's outputs.
    TensorVec branch0{conv0->getOutput(0)};
    for (int rep = 0; rep < 4; ++rep)
        branch0 = v.addSubGraph(subG, branch0);
    auto relu4 = g->addOp<ReluObj>(branch0[0], nullptr);

    // Second branch: four chained copies of subG1.
    Tensor w10 = g->addTensor({48, 256, 3, 3}, DataType::UInt32);
    auto conv10 = g->addOp<ConvObj>(relu0->getOutput(0), w10, nullptr);
    TensorVec branch1{conv10->getOutput(0)};
    for (int rep = 0; rep < 4; ++rep)
        branch1 = v.addSubGraph(subG1, branch1);
    auto relu10 = g->addOp<ReluObj>(branch1[0], nullptr);
    Tensor w1 = g->addTensor({96, 48, 3, 3}, DataType::UInt32);
    auto conv1 =
        g->addOp<ConvObj>(relu10->getOutput(), w1, nullptr, 1, 1, 2, 2);
    auto add1 =
        g->addOp<AddObj>(relu4->getOutput(), conv1->getOutput(), nullptr);

    // One more copy of subG after the two branches are merged.
    v.addSubGraph(subG, TensorVec{add1->getOutput(0)});

    // EXPECT_EQ (instead of EXPECT_TRUE on a comparison) prints the actual
    // value on failure; unsigned literals match the size_t being compared.
    EXPECT_EQ(g->getOperators().size(), 52u);
    vector<MatchGraph> subgs = v.findMatch(subG);
    EXPECT_EQ(subgs.size(), 5u);

    vector<MatchGraph> subgs1 = v.findMatch(subG1);
    EXPECT_EQ(subgs1.size(), 4u);

    // test replace
    Tensor sii0 =
        make_ref<TensorObj>(Shape{1, 96, 28, 28}, DataType::UInt32, runtime);
    SubGraph subG2 = make_ref<SubGraphObj>(runtime, TensorVec{sii0});
    {
        auto srelu0 = subG2->addOp<ReluObj>(sii0, nullptr);
        auto sw0 = subG2->addTensor({96, 96, 3, 3}, DataType::UInt32);
        auto sconv0 = subG2->addOp<ConvObj>(srelu0->getOutput(0), sw0, nullptr);
        subG2->setOutputs(sconv0->getOutputs());
    }

    v.replaceSubGraph(subG, subG2);
    EXPECT_EQ(g->getOperators().size(), 37u);
}
|
||||||
|
|
||||||
|
// Matches and replaces a pattern that has two outside inputs.
TEST(MatchGraph, multi_input) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    // Pattern: ReduceMean on the first input, then Sub with the second input.
    Tensor patIn0 =
        make_ref<TensorObj>(Shape{3, 4, 5, 2}, DataType::UInt32, runtime);
    Tensor patIn1 =
        make_ref<TensorObj>(Shape{3, 4, 5}, DataType::UInt32, runtime);
    SubGraph subG = make_ref<SubGraphObj>(runtime, TensorVec{patIn0, patIn1});
    {
        auto mean =
            subG->addOp<ReduceMeanObj>(patIn0, nullptr, vector<int>{3}, false);
        auto diff = subG->addOp<SubObj>(mean->getOutput(0), patIn1, nullptr);
        subG->setOutputs(diff->getOutputs());
    }

    // Replacement: same layout, but with Add in place of Sub. Its inputs are
    // clones of the pattern's input tensors.
    SubGraph replaceG = make_ref<SubGraphObj>(
        runtime, TensorVec{patIn0->clone(runtime), patIn1->clone(runtime)});
    {
        auto mean =
            replaceG->addOp<ReduceMeanObj>(replaceG->getInputsFromOutside()[0],
                                           nullptr, vector<int>{3}, false);
        auto sum = replaceG->addOp<AddObj>(mean->getOutput(0),
                                           replaceG->getInputsFromOutside()[1],
                                           nullptr);
        replaceG->setOutputs(sum->getOutputs());
    }

    Graph g = make_ref<GraphObj>(runtime);
    SubGraphRewriter v(g);
    {
        Tensor gIn0 = g->addTensor({3, 4, 5, 2}, DataType::UInt32);
        Tensor gIn1 = g->addTensor({3, 4, 5, 2}, DataType::UInt32);
        auto add0 = g->addOp<AddObj>(gIn0, gIn1, nullptr);
        auto relu0 = g->addOp<ReluObj>(add0->getOutput(), nullptr);
        auto reduce0 = g->addOp<ReduceMeanObj>(relu0->getOutput(), nullptr,
                                               vector<int>{3}, false);
        auto subOuts =
            v.addSubGraph(subG, {add0->getOutput(), reduce0->getOutput()});

        Tensor gIn2 = g->addTensor({3, 4, 5}, DataType::UInt32);
        auto pow0 = g->addOp<PowObj>(subOuts[0], gIn2, nullptr);

        Tensor gIn3 = g->addTensor({3, 4, 5, 2}, DataType::UInt32);
        auto reduce1 =
            g->addOp<ReduceMeanObj>(gIn3, nullptr, vector<int>{3}, false);
        auto sub0 = g->addOp<SubObj>(reduce1->getOutput(0), pow0->getOutput(0),
                                     nullptr);

        auto found = v.findMatch(subG);
        EXPECT_EQ(2, found.size());

        // Give reduce1's output a second consumer and search again.
        auto div0 = g->addOp<DivObj>(reduce1->getOutput(0), gIn2, nullptr);
        auto add1 =
            g->addOp<AddObj>(sub0->getOutput(), div0->getOutput(), nullptr);
        found = v.findMatch(subG);
        EXPECT_EQ(1, found.size());

        // The two matched sub graphs overlap, so only one of them is replaced.
        v.replaceSubGraph(subG, replaceG);
        EXPECT_EQ(1, v.findMatch(replaceG).size());
    }
}
|
||||||
|
|
||||||
|
// Matches and replaces a pattern that exposes two outputs to the outside.
TEST(MatchGraph, multi_output) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    // Pattern: maxpool feeding two branches (conv -> relu, pad -> avgpool);
    // the end of each branch is an output of the sub graph.
    Tensor patIn =
        make_ref<TensorObj>(Shape{1, 192, 71, 71}, DataType::UInt32, runtime);
    SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{patIn});
    {
        auto pooled =
            subg0->addOp<MaxPoolObj>(patIn, nullptr, 3, 3, 0, 0, 0, 0, 2, 2);
        Tensor weight = subg0->addTensor(Shape{64, 192, 1, 1}, DataType::UInt32);
        auto conv =
            subg0->addOp<ConvObj>(pooled->getOutput(0), weight, nullptr);
        auto activated = subg0->addOp<ReluObj>(conv->getOutput(0), nullptr);

        auto padded = subg0->addOp<PadObj>(pooled->getOutput(0), nullptr,
                                           vector<int>{0, 0, 1, 1, 0, 0, 1, 1},
                                           std::nullopt);
        auto averaged = subg0->addOp<AvgPoolObj>(padded->getOutput(0), nullptr,
                                                 3, 3, 0, 0, 0, 0, 1, 1);
        subg0->setOutputs(
            TensorVec{activated->getOutput(0), averaged->getOutput(0)});
    }

    // Replacement: avgpool feeding a relu and a split; the outputs are the
    // split's output 1 and the relu's output.
    SubGraph subg1 =
        make_ref<SubGraphObj>(runtime, TensorVec{patIn->clone(runtime)});
    {
        auto averaged = subg1->addOp<AvgPoolObj>(
            subg1->getInputsFromOutside()[0], nullptr, 3, 3, 0, 0, 0, 0, 2, 2);

        auto activated = subg1->addOp<ReluObj>(averaged->getOutput(0), nullptr);

        auto parts =
            subg1->addOp<SplitObj>(averaged->getOutput(0), std::nullopt, 1, 3);
        subg1->setOutputs(
            TensorVec{parts->getOutput(1), activated->getOutput(0)});
    }

    Graph g = make_ref<GraphObj>(runtime);
    SubGraphRewriter v(g);
    {
        auto graphIn = g->cloneTensor(patIn);
        auto outs = v.addSubGraph(subg0, {graphIn});
        EXPECT_EQ(2, outs.size());

        // Consumers of the first sub graph output.
        Tensor wA = g->addTensor(Shape{96, 64, 3, 3}, DataType::UInt32);
        auto convA = g->addOp<ConvObj>(outs[0], wA, nullptr, 1, 1);
        auto reluA = g->addOp<ReluObj>(convA->getOutput(0), nullptr);

        Tensor wB = g->addTensor(Shape{96, 96, 3, 3}, DataType::UInt32);
        auto convB = g->addOp<ConvObj>(reluA->getOutput(), wB, nullptr, 1, 1);
        auto reluB = g->addOp<ReluObj>(convB->getOutput(0), nullptr);

        // Consumers of the second sub graph output.
        Tensor wC = g->addTensor(Shape{32, 192, 1, 1}, DataType::UInt32);
        auto convC = g->addOp<ConvObj>(outs[1], wC, nullptr);
        auto reluC = g->addOp<ReluObj>(convC->getOutput(0), nullptr);

        Tensor extra0 = g->addTensor(Shape{1, 64, 35, 35}, DataType::UInt32);
        Tensor extra1 = g->addTensor(Shape{1, 64, 35, 35}, DataType::UInt32);
        auto concat = g->addOp<ConcatObj>(
            TensorVec{extra0, extra1, reluB->getOutput(), reluC->getOutput()},
            nullptr, 1);
        auto merged = concat->getOutput();
        EXPECT_TRUE((merged->getDims() == Shape{1, 256, 35, 35}));
    }

    auto before = v.findMatch(subg0);
    EXPECT_EQ(1, before.size());

    v.replaceSubGraph(subg0, subg1);
    auto after = v.findMatch(subg1);
    EXPECT_EQ(1, after.size());
}
|
||||||
|
|
||||||
|
// gcn
|
||||||
|
// Matches and replaces patterns that have two outside inputs and two outside
// outputs. subg0 and subg1 are built from the same operators but declare
// their inputs, the Add operands and their outputs in swapped order.
TEST(MatchGraph, multi_input_output) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();

    // Input tensors shared by all three sub graphs below.
    Tensor inLarge =
        make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
    Tensor inSmall =
        make_ref<TensorObj>(Shape{1, 64, 56, 56}, DataType::UInt32, runtime);

    // subg0
    SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{inLarge, inSmall});
    {
        auto sliced = subg0->addOp<SliceObj>(
            inLarge, nullptr, vector<int>{0, 0}, vector<int>{55, 55},
            vector<int>{2, 3}, std::nullopt);
        auto reluA = subg0->addOp<ReluObj>(sliced->getOutput(0), nullptr);
        Tensor wShared =
            subg0->addTensor(Shape{256, 64, 1, 1}, DataType::UInt32);
        auto convA =
            subg0->addOp<ConvObj>(reluA->getOutput(0), wShared, nullptr);

        // The second conv uses the same weight tensor as the first one.
        auto convB = subg0->addOp<ConvObj>(inSmall, wShared, nullptr);
        auto sum = subg0->addOp<AddObj>(convA->getOutput(0),
                                        convB->getOutput(0), nullptr);

        auto reluB = subg0->addOp<ReluObj>(sum->getOutput(0), nullptr);
        Tensor wTail = subg0->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
        auto convTail =
            subg0->addOp<ConvObj>(reluB->getOutput(0), wTail, nullptr);
        auto pooled = subg0->addOp<MaxPoolObj>(reluB->getOutput(0), nullptr, 3,
                                               3, 0, 0, 0, 0, 2, 2);
        subg0->setOutputs(
            TensorVec{convTail->getOutput(0), pooled->getOutput(0)});
    }

    // subg1: same operators as subg0 with swapped input, Add-operand and
    // output order.
    SubGraph subg1 = make_ref<SubGraphObj>(runtime, TensorVec{inSmall, inLarge});
    {
        auto sliced = subg1->addOp<SliceObj>(
            inLarge, nullptr, vector<int>{0, 0}, vector<int>{55, 55},
            vector<int>{2, 3}, std::nullopt);
        auto reluA = subg1->addOp<ReluObj>(sliced->getOutput(0), nullptr);
        Tensor wShared =
            subg1->addTensor(Shape{256, 64, 1, 1}, DataType::UInt32);
        auto convA =
            subg1->addOp<ConvObj>(reluA->getOutput(0), wShared, nullptr);

        auto convB = subg1->addOp<ConvObj>(inSmall, wShared, nullptr);
        auto sum = subg1->addOp<AddObj>(convB->getOutput(0),
                                        convA->getOutput(0), nullptr);

        auto reluB = subg1->addOp<ReluObj>(sum->getOutput(0), nullptr);
        Tensor wTail = subg1->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
        auto convTail =
            subg1->addOp<ConvObj>(reluB->getOutput(0), wTail, nullptr);
        auto pooled = subg1->addOp<MaxPoolObj>(reluB->getOutput(0), nullptr, 3,
                                               3, 0, 0, 0, 0, 2, 2);
        subg1->setOutputs(
            TensorVec{pooled->getOutput(0), convTail->getOutput(0)});
    }

    // subg2: the replacement, built from Extend/Slice/Add/Relu and ending in
    // a conv plus an avgpool.
    SubGraph subg2 = make_ref<SubGraphObj>(runtime, TensorVec{inLarge, inSmall});
    {
        auto extLarge = subg2->addOp<ExtendObj>(inLarge, nullptr, 1, 3);

        auto sliced = subg2->addOp<SliceObj>(
            extLarge->getOutput(0), nullptr, vector<int>{0, 0},
            vector<int>{55, 55}, vector<int>{2, 3}, std::nullopt);

        auto extSmall = subg2->addOp<ExtendObj>(inSmall, nullptr, 1, 3);
        auto sum = subg2->addOp<AddObj>(extSmall->getOutput(0),
                                        sliced->getOutput(0), nullptr);

        auto reluB = subg2->addOp<ReluObj>(sum->getOutput(0), nullptr);
        Tensor wTail = subg2->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
        auto convTail =
            subg2->addOp<ConvObj>(reluB->getOutput(0), wTail, nullptr);
        auto averaged = subg2->addOp<AvgPoolObj>(reluB->getOutput(0), nullptr,
                                                 3, 3, 0, 0, 0, 0, 2, 2);
        subg2->setOutputs(
            TensorVec{convTail->getOutput(0), averaged->getOutput(0)});
    }

    Graph g = make_ref<GraphObj>(runtime);
    SubGraphRewriter v(g);
    {
        auto graphIn = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
        auto relu = g->addOp<ReluObj>(graphIn, nullptr);
        auto maxPool = g->addOp<MaxPoolObj>(relu->getOutput(0), nullptr, 3, 3,
                                            0, 0, 1, 1, 2, 2);
        // Instantiate both patterns on the same pair of tensors, each in the
        // input order that pattern declares.
        auto out0 =
            v.addSubGraph(subg0, {relu->getOutput(0), maxPool->getOutput(0)});
        auto out1 =
            v.addSubGraph(subg1, {maxPool->getOutput(0), relu->getOutput(0)});
        EXPECT_EQ(2, out0.size());
        EXPECT_EQ(2, out1.size());
        auto div = g->addOp<DivObj>(out0[0], out1[1], nullptr);
        auto sub = g->addOp<SubObj>(out0[1], out1[0], nullptr);
    }

    EXPECT_EQ(2, v.findMatch(subg0).size());
    EXPECT_EQ(2, v.findMatch(subg1).size());
    v.replaceSubGraph(subg0, subg2);
    EXPECT_EQ(v.findMatch(subg2).size(), 2);
}
|
||||||
|
|
||||||
|
/* One Node having two or more successors is not supported yet.
|
||||||
|
TEST(MatchGraph, same_successor) {
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
// subg0
|
||||||
|
Tensor i0 =
|
||||||
|
make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
|
||||||
|
Tensor i1 =
|
||||||
|
make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
|
||||||
|
SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
|
||||||
|
{
|
||||||
|
auto add0 = subg0->addOp<AddObj>(i0, i1, nullptr);
|
||||||
|
auto add1 = subg0->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
|
||||||
|
auto add2 = subg0->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
|
||||||
|
|
||||||
|
auto mul = subg0->addOp<MulObj>(add1->getOutput(0), i1, nullptr);
|
||||||
|
auto div = subg0->addOp<DivObj>(add2->getOutput(0), i1, nullptr);
|
||||||
|
|
||||||
|
auto sub =
|
||||||
|
subg0->addOp<SubObj>(mul->getOutput(0), div->getOutput(0), nullptr);
|
||||||
|
|
||||||
|
subg0->setOutputs(TensorVec{sub->getOutput(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
// pattern
|
||||||
|
SubGraph pattern1 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
|
||||||
|
{
|
||||||
|
auto add0 = pattern1->addOp<AddObj>(i0, i1, nullptr);
|
||||||
|
auto add1 = pattern1->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
|
||||||
|
auto div = pattern1->addOp<DivObj>(add1->getOutput(0), i1, nullptr);
|
||||||
|
pattern1->setOutputs(TensorVec{add0->getOutput(0), div->getOutput(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
// pattern
|
||||||
|
SubGraph pattern2 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
|
||||||
|
{
|
||||||
|
auto add0 = pattern2->addOp<AddObj>(i0, i1, nullptr);
|
||||||
|
auto add1 = pattern2->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
|
||||||
|
pattern2->setOutputs(TensorVec{add0->getOutput(0), add1->getOutput(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
SubGraphRewriter v(g);
|
||||||
|
{
|
||||||
|
i0 = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
|
||||||
|
i1 = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
|
||||||
|
auto out0 = v.addSubGraph(subg0, {i0, i1});
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(1, v.findMatch(pattern1).size());
|
||||||
|
EXPECT_EQ(2, v.findMatch(pattern2).size());
|
||||||
|
v.replaceSubGraph(pattern2, pattern1);
|
||||||
|
EXPECT_EQ(v.findMatch(pattern2).size(), 2);
|
||||||
|
}*/
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue