From 43d47983236d1f92e2a11c0ecdb6d9577f180b76 Mon Sep 17 00:00:00 2001
From: wendy12022
Date: Mon, 17 Apr 2023 13:09:07 +0800
Subject: [PATCH] ADD: sub graph replacement. (#56)

reconfig: connections among operators and tensors are now managed by GraphObj.

add some comments

merge from master

merge from master

ADD: sub graph replacement

reconfigure inputs of the Resize op due to the operator input check.

ResizeObj::clone

clang format

fix some issues and add a test for multi-output.

replacement supports multiple inputs and multiple outputs.

add clone for all operators

add replaceSubGraph and addSubGraph

remove extra code

add more tests

remove extra print

Co-authored-by: Haojie Wang
---
 include/core/graph.h            |  23 ++
 include/core/graph_match.h      | 108 ++++++++
 include/core/operator.h         |  12 +-
 include/core/runtime.h          |   1 +
 include/core/tensor_base.h      |  19 +-
 include/operators/batch_norm.h  |   1 -
 include/operators/split.h       |   2 +-
 src/core/graph.cc               |  67 ++++-
 src/core/graph_match.cc         | 465 ++++++++++++++++++++++++++++++++
 src/core/operator.cc            |  26 ++
 src/operators/conv.cc           |   2 +-
 src/operators/pooling.cc        |   1 -
 test/core/test_graph_replace.cc | 422 +++++++++++++++++++++++++++++
 13 files changed, 1138 insertions(+), 11 deletions(-)
 create mode 100644 include/core/graph_match.h
 create mode 100644 src/core/graph_match.cc
 create mode 100644 test/core/test_graph_replace.cc

diff --git a/include/core/graph.h b/include/core/graph.h
index 4e654caa..cbeceac1 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -22,6 +22,27 @@ class GraphObj : public Object {
     Tensor cloneTensor(const Tensor &tensor) {
         return addTensor(tensor->clone(runtime));
     }
+    void removeOperator(Operator op) {
+        auto it = std::find(ops.begin(), ops.end(), op);
+        if (it != ops.end())
+            ops.erase(it);
+    }
+
+    void removeTensor(Tensor tensor) {
+        auto it = std::find(tensors.begin(), tensors.end(), tensor);
+        if (it != tensors.end())
+            tensors.erase(it);
+    }
+
+    void deleteConnection(Tensor tensor, Operator op);
+    void addConnection(Tensor tensor, Operator op);
+    void replaceConnection(Tensor oldInput, Tensor newInput, Operator op);
+
+    Operator cloneOperator(Operator op, TensorVec inputs, TensorVec outputs) {
+        auto opClone = op->clone(inputs, outputs);
+        addOperatorAndConnect(opClone);
+        return opClone;
+    }

     const TensorVec &getTensors() const { return tensors; }
     const OpVec &getOperators() const { return ops; }
@@ -79,6 +100,8 @@ class GraphObj : public Object {
         return ret;
     }

+    bool checkValid() const;
+
   private:
     /**
      * @brief Add reverse connections and Op relationship in ctor.
diff --git a/include/core/graph_match.h b/include/core/graph_match.h
new file mode 100644
index 00000000..e81857a9
--- /dev/null
+++ b/include/core/graph_match.h
@@ -0,0 +1,108 @@
+#pragma once
+
+#include "core/graph.h"
+namespace infini {
+class SubGraphObj : public GraphObj {
+    TensorVec ins;  // inputs from outer predecessors; the order is fixed
+    TensorVec outs; // outputs to outer successors; the order is fixed
+
+  public:
+    SubGraphObj(Runtime runtime, const TensorVec &inputs);
+    void setOutputs(const TensorVec &tensors) { outs = tensors; }
+    TensorVec getInputsFromOutside() const { return ins; }
+    TensorVec getOutputs2Outside() const { return outs; }
+    bool isInputFromOutside(Tensor t) const {
+        return std::find(ins.begin(), ins.end(), t) != ins.end();
+    }
+    bool isOutput2Outside(Tensor t) const {
+        return std::find(outs.begin(), outs.end(), t) != outs.end();
+    }
+    bool isHead(const Operator &op) const {
+        for (auto in : ins) {
+            auto ops = in->getTargets();
+            if (std::find(ops.begin(), ops.end(), op) != ops.end())
+                return true;
+        }
+        return false;
+    }
+    bool isTail(const Operator &op) const {
+        for (auto out : outs) {
+            if (op == out->getSource())
+                return true;
+        }
+        return false;
+    }
+};
+using SubGraph = Ref<SubGraphObj>;
+
+// Describes a match for subgraph replacement.
+class GraphMatchObj {
+    std::unordered_set<Operator> ops;
+    std::unordered_map<Operator, Operator> opMap;        // anchor->pattern
+    std::unordered_map<Operator, Operator> opMapReverse; // pattern->anchor
+    std::unordered_map<Tensor, Tensor> tensorMap;        // pattern->anchor
+    SubGraph pattern;
+
+  public:
+    GraphMatchObj(SubGraph pattern) : pattern(pattern) {}
+    Ref<GraphMatchObj> clone();
+    void addOp(const Operator &anchorOp, const Operator &patternOp);
+    bool hasContained(const Operator &op) const { return opMap.count(op) > 0; }
+    bool hasMatched(const Operator &op) const {
+        return opMapReverse.count(op) > 0;
+    }
+
+    Tensor getAnchorByPattern(const Tensor &t) {
+        IT_ASSERT(tensorMap.count(t) > 0);
+        return tensorMap.at(t);
+    }
+
+    Operator getAnchorByPattern(const Operator &op) {
+        IT_ASSERT(opMapReverse.count(op) > 0);
+        return opMapReverse.at(op);
+    }
+
+    TensorVec getInputs() const;
+    TensorVec getOutputs() const;
+    std::unordered_set<Operator> getOps() const { return ops; }
+    std::string toString() const;
+
+  private:
+    void recordOutsideTensorMap(const Operator &patternOp,
+                                const Operator &anchorOp);
+};
+using MatchGraph = Ref<GraphMatchObj>;
+
+class SubGraphRewriter {
+    SubGraph pattern;
+    Graph graph;
+
+  public:
+    SubGraphRewriter(Graph g) : graph(g) {}
+    vector<MatchGraph> findMatch(const SubGraph &pattern);
+    void replaceSubGraph(const SubGraph &pattern, const SubGraph &replacement);
+    TensorVec addSubGraph(const SubGraph &pattern, const TensorVec &inputs);
+
+  private:
+    void removeSubGraph(MatchGraph match);
+    bool MatchNode(const Operator &a, const Operator &b, bool isHead,
+                   bool isTail) const;
+    OpLists matchInCandidates(const OpVec &ops, const Operator &opDst,
+                              bool isHead, bool isTail);
+    bool findMatch(const MatchGraph &lastMatched, const Operator &opLastMatched,
+                   const Operator &opDst, vector<MatchGraph> &matched);
+    bool findMatch2(const MatchGraph &lastMatched,
+                    const Operator &opLastMatched, const Operator &opDst,
+                    vector<MatchGraph> &matched);
+    void updateMatchedGraph(const MatchGraph &lastMatched, OpLists &opMatched,
+                            vector<MatchGraph> &gMatched, Operator dst);
+
+    bool checkReplacement(const SubGraph &pattern, const SubGraph &other) const;
+    bool checkReplacement(const TensorVec &left, const TensorVec &right) const;
+    bool isReplacable(const Tensor &l, const Tensor &r) const;
+    bool checkOverlapsWithPreviousMatch(
+        const MatchGraph &match,
+        const std::unordered_set<Operator> &nodesToDelete) const;
+    bool checkMatchValid(const MatchGraph &match) const;
+};
+} // namespace infini
diff --git a/include/core/operator.h b/include/core/operator.h
index f584b385..87069096 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -1,6 +1,5 @@
 #pragma once
 #include "core/tensor.h"
-
 namespace infini {

 enum class OpType {
@@ -139,7 +138,10 @@ struct OpPerfKey {
     }
 };
+class GraphObj;
 class OperatorObj : public Object {
+    friend class GraphObj;
+
   protected:
     OpType type;
     TensorVec inputs;
@@ -189,8 +191,6 @@ class OperatorObj : public Object {
         IT_ASSERT(i < outputs.size(), "Index exceeded");
         return outputs.at(i);
     }
-    void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
-    void addSuccessors(const Operator &op) { successors.emplace_back(op); }
     OpVec getPredecessors() const { return wrefs_to_refs(predecessors); }
     OpVec getSuccessors() const { return wrefs_to_refs(successors); }
     OpType getOpType() const { return type; }
@@ -225,6 +225,12 @@ class OperatorObj : public Object {
      * and output shapes.
      */
     virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
+
+    void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
+    void addSuccessors(const Operator &op) { successors.emplace_back(op); }
+    void removePredecessors(const Operator &op);
+    void removeSuccessors(const Operator &op);
+    void replaceInput(Tensor t1, Tensor t2);
 };

 #define OP_CLONE(OpObj)                                                        \
diff --git a/include/core/runtime.h b/include/core/runtime.h
index 60b7ad72..53920fdb 100644
--- a/include/core/runtime.h
+++ b/include/core/runtime.h
@@ -25,6 +25,7 @@ enum class OpType;

 using TensorVec = vector<Tensor>;
 using OpVec = vector<Operator>;
+using OpLists = list<Operator>;

 using VType = uint32_t;

diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h
index 3b9bff3f..54d65ffd 100644
--- a/include/core/tensor_base.h
+++ b/include/core/tensor_base.h
@@ -3,9 +3,11 @@
 #include "core/data_type.h"
 #include "core/object.h"
 #include "core/runtime.h"
-
 namespace infini {
+class GraphObj;
 class TensorBaseObj : public Object {
+    friend class GraphObj;
+
   public:
     // enum TensorType {
     //     Input,
@@ -44,13 +46,24 @@ class TensorBaseObj : public Object {
     DataType getDType() const { return dtype; }
     Runtime getRuntime() const { return runtime; }

-    void addTarget(const Operator &op) { targets.emplace_back(op); }
-    void setSource(const Operator &op) { source = op; }
+    // std::pair<Operator *, int> getOutputOfWithIndex();

     bool hasTarget() const { return !targets.empty(); }

     OpVec getTargets() const { return wrefs_to_refs(targets); }
     Operator getSource() const { return source.lock(); }
+
+  private:
+    void addTarget(const Operator &op) { targets.emplace_back(op); }
+    void setSource(const Operator &op) { source = op; }
+    void removeTarget(const Operator &op) {
+        for (auto itr = targets.begin(); itr != targets.end();) {
+            if (itr->lock() == op)
+                itr = targets.erase(itr);
+            else
+                ++itr;
+        }
+    }

     // std::pair<Operator *, int> getSourceWithIndex();

     // bool setScalar(VType val) {
diff --git a/include/operators/batch_norm.h b/include/operators/batch_norm.h
index fbee21fd..76f1eff4 100644
--- a/include/operators/batch_norm.h
+++ b/include/operators/batch_norm.h
@@ -35,7 +35,6 @@ class BatchNormObj : public OperatorObj {
                  float eps = 1e-5, bool training = false);
     OP_CLONE(BatchNormObj);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
-
     std::string toString() const override;

     // output size will be 3 when training
diff --git a/include/operators/split.h b/include/operators/split.h
index e2f9f4a3..61aa43a2 100644
--- a/include/operators/split.h
+++ b/include/operators/split.h
@@ -48,4 +48,4 @@ class SplitObj : public OperatorObj {
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
 };
-} // namespace infini
\ No newline at end of file
+} // namespace infini
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 800f4ca5..3bc525d4 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -143,6 +143,71 @@ OpVec GraphObj::getComputeOps() const {
         if (op->isComputeOp())
             opList.emplace_back(op);
     return opList;
-};
+}
+
+void GraphObj::deleteConnection(Tensor tensor, Operator op) {
+    // op must be one of the tensor's targets
+    IT_ASSERT(std::find(tensor->getTargets().begin(),
+                        tensor->getTargets().end(),
+                        op) != tensor->getTargets().end());
+    tensor->removeTarget(op);
+    if (tensor->getSource()) {
+        tensor->getSource()->removeSuccessors(op);
+        op->removePredecessors(tensor->getSource());
+    }
+}
+
+// add op as a target of the tensor
+void GraphObj::addConnection(Tensor tensor, Operator op) {
+    tensor->addTarget(op);
+    if (tensor->getSource()) {
+        tensor->getSource()->addSuccessors(op);
+        op->addPredecessors(tensor->getSource());
+    }
+}
+
+void GraphObj::replaceConnection(Tensor oldTensor, Tensor newTensor,
+                                 Operator op) {
+    // op must be a target of the old tensor
+    IT_ASSERT(std::find(oldTensor->getTargets().begin(),
+                        oldTensor->getTargets().end(),
+                        op) != oldTensor->getTargets().end());
+    addConnection(newTensor, op);
+    deleteConnection(oldTensor, op);
+    op->replaceInput(oldTensor, newTensor);
+}
+
+// A tensor's source and targets must all be in "ops".
+// A tensor with neither a source nor any targets must not exist.
+// The inputs and outputs of every operator in "ops" must be in "tensors".
+// The predecessors and successors of every operator in "ops" must be in "ops".
+bool GraphObj::checkValid() const {
+    for (auto tensor : tensors) {
+        IT_ASSERT(!(tensor->getTargets().size() == 0 &&
+                    nullptr == tensor->getSource()));
+        for (auto op : tensor->getTargets()) {
+            IT_ASSERT(std::find(ops.begin(), ops.end(), op) != ops.end());
+        }
+        auto op = tensor->getSource();
+        IT_ASSERT(!(op && std::find(ops.begin(), ops.end(), op) == ops.end()));
+    }
+    for (auto op : ops) {
+        for (auto tensor : op->getInputs()) {
+            IT_ASSERT(std::find(tensors.begin(), tensors.end(), tensor) !=
+                      tensors.end());
+        }
+        for (auto tensor : op->getOutputs()) {
+            IT_ASSERT(std::find(tensors.begin(), tensors.end(), tensor) !=
+                      tensors.end());
+        }
+        for (auto pre : op->getPredecessors()) {
+            IT_ASSERT(std::find(ops.begin(), ops.end(), pre) != ops.end());
+        }
+        for (auto suc : op->getSuccessors()) {
+            IT_ASSERT(std::find(ops.begin(), ops.end(), suc) != ops.end());
+        }
+    }
+    return true;
+}
 } // namespace infini
diff --git a/src/core/graph_match.cc b/src/core/graph_match.cc
new file mode 100644
index 00000000..d2b35e0c
--- /dev/null
+++ b/src/core/graph_match.cc
@@ -0,0 +1,465 @@
+#include "core/graph_match.h"
+
+namespace infini {
+Ref<GraphMatchObj> GraphMatchObj::clone() {
+    auto newObj = make_ref<GraphMatchObj>(pattern);
+    newObj->ops = ops;
+    newObj->opMap = opMap;
+    newObj->opMapReverse = opMapReverse;
+    newObj->tensorMap = tensorMap;
+    return newObj;
+}
+
+void GraphMatchObj::addOp(const Operator &anchorOp, const Operator &patternOp) {
+    ops.emplace(anchorOp);
+    opMap.emplace(anchorOp, patternOp);
+    opMapReverse.emplace(patternOp, anchorOp);
+    recordOutsideTensorMap(patternOp, anchorOp);
+}
+
+TensorVec GraphMatchObj::getInputs() const {
+    TensorVec ret;
+    for (auto t : pattern->getInputsFromOutside()) {
+        IT_ASSERT(tensorMap.count(t) > 0);
+        ret.push_back(tensorMap.at(t));
+    }
+    return ret;
+}
+
+TensorVec GraphMatchObj::getOutputs() const {
+    TensorVec ret;
+    for (auto t : pattern->getOutputs2Outside()) {
+        IT_ASSERT(tensorMap.count(t) > 0);
+        ret.push_back(tensorMap.at(t));
+    }
+    return ret;
+}
+
+std::string GraphMatchObj::toString() const {
+    std::ostringstream oss;
+    oss << "MatchGraph operators:\n";
+    for (const auto &op : ops) {
+        vector<UidBaseType> preds, succs;
+        for (auto &o : op->getPredecessors())
+            preds.emplace_back(o->getGuid());
+        for (auto &o : op->getSuccessors())
+            succs.emplace_back(o->getGuid());
+        oss << "OP " << op->getGuid();
+        oss << ", pred " << vecToString(preds);
+        oss << ", succ " << vecToString(succs);
+        oss << ", " << op << "\n";
+    }
+    return oss.str();
+}
+
+// If a pattern tensor comes from outside, find the corresponding anchor
+// tensor and record the mapping.
+void GraphMatchObj::recordOutsideTensorMap(const Operator &patternOp,
+                                           const Operator &anchorOp) {
+    for (size_t i = 0; i < patternOp->getInputs().size(); ++i) {
+        if (pattern->isInputFromOutside(patternOp->getInputs(i)))
+            tensorMap.emplace(patternOp->getInputs(i), anchorOp->getInputs(i));
+    }
+    for (size_t i = 0; i < patternOp->getOutputs().size(); ++i) {
+        if (pattern->isOutput2Outside(patternOp->getOutput(i)))
+            tensorMap.emplace(patternOp->getOutput(i), anchorOp->getOutput(i));
+    }
+}
+
+SubGraphObj::SubGraphObj(Runtime runtime, const TensorVec &inputs)
+    : GraphObj(runtime), ins(inputs) {
+    for (auto t : ins)
+        tensors.push_back(t);
+}
+
+vector<MatchGraph> SubGraphRewriter::findMatch(const SubGraph &pattern) {
+    this->pattern = pattern;
+    vector<MatchGraph> matches;
+    bool firstHead = true, retStatus = true;
+    for (auto input : pattern->getInputsFromOutside()) {
+        auto inputOf = input->getTargets();
+        for (auto opHead : inputOf) {
+            if (std::find(pattern->getOperators().begin(),
+                          pattern->getOperators().end(),
+                          opHead) == pattern->getOperators().end())
+                continue; // does not belong to the pattern
+            if (opHead->getPredecessors().size() > 0) // not a head
+                continue;
+            if (firstHead) {
+                firstHead = false;
+                if (!findMatch(nullptr, nullptr, opHead, matches)) {
+                    retStatus = false;
+                    break;
+                }
+            } else {
+                if (!findMatch2(nullptr, nullptr, opHead, matches)) {
+                    retStatus = false;
+                    break;
+                }
+            }
+        }
+        if (!retStatus)
+            break;
+    }
+
+    vector<MatchGraph> ret;
+    for (auto match : matches) {
+        if (checkMatchValid(match))
+            ret.push_back(match);
+    }
+    return ret;
+}
+
+bool SubGraphRewriter::findMatch(const MatchGraph &gLastMatch,
+                                 const Operator &opLastMatch,
+                                 const Operator &opPattern,
+                                 vector<MatchGraph> &gMatch) {
+    OpVec candidates =
+        opLastMatch ? opLastMatch->getSuccessors() : graph->getOperators();
+    OpLists nodesMatch =
+        matchInCandidates(candidates, opPattern, pattern->isHead(opPattern),
+                          pattern->isTail(opPattern));
+
+    IT_ASSERT(nodesMatch.size() <= 1 || !opLastMatch);
+    updateMatchedGraph(gLastMatch, nodesMatch, gMatch, opPattern);
+
+    if (nodesMatch.size() == 0) {
+        return false;
+    }
+
+    // dst is matched; process its successors recursively
+    for (auto successorPattern : opPattern->getSuccessors()) {
+        bool bRet = false;
+        if (opLastMatch) {
+            IT_ASSERT(nodesMatch.size() == 1);
+            if (gLastMatch->hasMatched(successorPattern))
+                continue;
+            bRet = findMatch(gLastMatch, nodesMatch.front(), successorPattern,
+                             gMatch);
+        } else {
+            IT_ASSERT(nodesMatch.size() == gMatch.size());
+            auto tmp1 = gMatch;
+            auto itr1 = nodesMatch.begin();
+            auto itr2 = gMatch.begin();
+            for (; itr1 != nodesMatch.end() && itr2 != gMatch.end(); ++itr2) {
+                if (findMatch(*itr2, *itr1, successorPattern, tmp1)) {
+                    bRet = true;
+                    ++itr1;
+                } else
+                    itr1 = nodesMatch.erase(itr1);
+            }
+            gMatch = tmp1;
+        }
+        // not found, return false
+        if (!bRet) {
+            return false;
+        }
+    }
+    return true;
+}
+
+bool SubGraphRewriter::findMatch2(const MatchGraph &gLastMatch,
+                                  const Operator &opLastMatch,
+                                  const Operator &opPattern,
+                                  vector<MatchGraph> &matches) {
+    vector<MatchGraph> curMatches;
+    for (auto match : matches) {
+        OpVec candidates =
+            opLastMatch ? opLastMatch->getSuccessors() : graph->getOperators();
+        // filter out candidates that are already part of this match
+        for (auto itr2 = candidates.begin(); itr2 != candidates.end();) {
+            if (match->hasContained(
+                    *itr2)) // already belongs to the matched subgraph
+                itr2 = candidates.erase(itr2);
+            else
+                ++itr2;
+        }
+
+        OpLists nodesMatch = matchInCandidates(
+            candidates, opPattern, opPattern->getPredecessors().size() == 0,
+            opPattern->getSuccessors().size() == 0);
+
+        // no matching nodes found; do not add this match to curMatches
+        if (nodesMatch.size() == 0) {
+            continue;
+        }
+
+        for (auto node : nodesMatch) {
+            auto curMatch = match->clone();
+            curMatch->addOp(node, opPattern); // anchor and pattern
+
+            // add to curMatches
+            curMatches.push_back(curMatch);
+
+            // dst is matched; process its successors recursively
+            for (auto successorPattern : opPattern->getSuccessors()) {
+                if (match->hasMatched(successorPattern)) // already matched
+                    continue;
+                if (!findMatch(curMatch, node, successorPattern, curMatches)) {
+                    // curMatch has already been removed from curMatches in
+                    // "findMatch", so just break
+                    break;
+                }
+            }
+        }
+    }
+    matches = curMatches;
+    return true;
+}
+
+OpLists SubGraphRewriter::matchInCandidates(const OpVec &ops,
+                                            const Operator &opPattern,
+                                            bool isHead, bool isTail) {
+    OpLists ret;
+    for (auto op : ops) {
+        if (MatchNode(opPattern, op, isHead, isTail))
+            ret.push_back(op);
+    }
+    return ret;
+}
+
+bool SubGraphRewriter::MatchNode(const Operator &a, const Operator &b,
+                                 bool isHead, bool isTail) const {
+    if (a->getOpType() != b->getOpType())
+        return false;
+    if (a->hash() != b->hash())
+        return false;
+
+    if (!isHead)
+        if (a->getPredecessors().size() != b->getPredecessors().size())
+            return false;
+
+    if (!isTail)
+        if (a->getSuccessors().size() != b->getSuccessors().size())
+            return false;
+    return true;
+}
+
+void SubGraphRewriter::updateMatchedGraph(const MatchGraph &gLastMatch,
+                                          OpLists &opMatch,
+                                          vector<MatchGraph> &gMatch,
+                                          Operator opPattern) {
+    if (opMatch.size() == 0) {
+        if (nullptr != gLastMatch) {
+            auto pos = std::find(gMatch.begin(), gMatch.end(), gLastMatch);
+            IT_ASSERT(pos != gMatch.end());
+            gMatch.erase(pos);
+        }
+    } else {
+        // the anchor is a head
+        if (nullptr == gLastMatch) {
+            for (auto op : opMatch) {
+                auto match = make_ref<GraphMatchObj>(pattern);
+                match->addOp(op, opPattern);
+                gMatch.push_back(match);
+            }
+        } else {
+            IT_ASSERT(opMatch.size() == 1);
+            gLastMatch->addOp(opMatch.front(), opPattern);
+        }
+    }
+}
+
+bool SubGraphRewriter::checkOverlapsWithPreviousMatch(
+    const MatchGraph &match,
+    const std::unordered_set<Operator> &nodesToDelete) const {
+    for (auto op : match->getOps()) {
+        if (nodesToDelete.count(op) > 0)
+            return false;
+    }
+    return true;
+}
+
+bool SubGraphRewriter::checkMatchValid(const MatchGraph &match) const {
+    for (auto t : pattern->getInputsFromOutside()) {
+        auto tAnchor = match->getAnchorByPattern(t);
+        // the corresponding predecessor must not belong to the match
+        auto preOpAnchor = tAnchor->getSource();
+        if (preOpAnchor && match->hasContained(preOpAnchor)) {
+            return false;
+        }
+    }
+    // check connections
+    for (auto opPattern : pattern->getOperators()) {
+        auto opAnchor = match->getAnchorByPattern(opPattern);
+        for (auto prePattern : opPattern->getPredecessors()) {
+            auto preAnchor = match->getAnchorByPattern(prePattern);
+            auto ops = opAnchor->getPredecessors();
+            if (std::find(ops.begin(), ops.end(), preAnchor) == ops.end())
+                return false;
+            ops = preAnchor->getSuccessors();
+            if (std::find(ops.begin(), ops.end(), opAnchor) == ops.end())
+                return false;
+        }
+    }
+    return true;
+}
+
+// Replace every subgraph of g that matches "pattern" with "replacement".
+void SubGraphRewriter::replaceSubGraph(const SubGraph &pattern,
+                                       const SubGraph &replacement) {
+    IT_ASSERT(checkReplacement(pattern, replacement));
+    this->pattern = pattern;
+
+    // find matches in the graph
+    auto matches = findMatch(pattern);
+
+    std::unordered_set<Operator> nodesToDelete;
+    map<Tensor, Tensor> replaceMap;
+    map<Tensor, Tensor> replaceMapReverse;
+    for (auto match : matches) {
+        // matches may overlap with each other; if some operator is already in
+        // another folded match, this match must be skipped
+        if (!checkOverlapsWithPreviousMatch(match, nodesToDelete))
+            continue;
+
+        auto inputs = match->getInputs();
+        for (auto &input : inputs) {
+            if (replaceMap.count(input) > 0)
+                input = replaceMap[input];
+        }
+        auto outputs = match->getOutputs();
+
+        // first, remove the matched ops from the targets of the inputs
+        for (auto input : inputs) {
+            for (auto op : input->getTargets()) {
+                if (match->hasContained(op)) {
+                    graph->deleteConnection(input, op);
+                }
+            }
+        }
+
+        // second, insert the replacement subgraph into the graph
+        auto newOutputs = addSubGraph(replacement, inputs);
+
+        // check the replaced outputs and record them
+        IT_ASSERT(outputs.size() == newOutputs.size());
+        for (size_t i = 0; i < outputs.size(); ++i) {
+            IT_ASSERT(isReplacable(outputs[i], newOutputs[i]));
+            replaceMap.emplace(outputs[i], newOutputs[i]);
+            replaceMapReverse.emplace(newOutputs[i], outputs[i]);
+        }
+
+        // third, redirect the successors of the old outputs to the new ones
+        for (auto output : outputs) {
+            auto successors = output->getTargets();
+            for (auto successor : successors) {
+                auto newOutput = replaceMap[output];
+                graph->replaceConnection(output, newOutput, successor);
+            }
+        }
+
+        // record the ops that need to be deleted
+        for (auto op : match->getOps())
+            nodesToDelete.insert(op);
+
+        // remove the match from the graph
+        for (auto op : match->getOps()) {
+            for (auto tensor : op->getInputs()) {
+                if (replaceMapReverse.count(tensor) > 0)
+                    tensor = replaceMapReverse[tensor];
+                if (std::find(inputs.begin(), inputs.end(), tensor) ==
+                    inputs.end()) {
+                    graph->removeTensor(tensor);
+                }
+            }
+            for (auto tensor : op->getOutputs()) {
+                graph->removeTensor(tensor);
+            }
+            graph->removeOperator(op);
+        }
+
+        IT_ASSERT(graph->checkValid());
+    }
+}
+
+// "inputs" must be tensors in the original graph
+TensorVec SubGraphRewriter::addSubGraph(const SubGraph &g,
+                                        const TensorVec &inputs) {
+    // check the inputs
+    for (auto input : inputs) {
+        auto tensors = graph->getTensors();
+        IT_ASSERT(std::find(tensors.begin(), tensors.end(), input) !=
+                  tensors.end());
+    }
+
+    // check compatibility with the subgraph's outside inputs
+    auto ins = g->getInputsFromOutside();
+    IT_ASSERT(checkReplacement(ins, inputs));
+
+    std::map<Tensor, Tensor> tensorMap;
+    for (size_t i = 0; i < ins.size(); ++i) {
+        tensorMap.emplace(ins[i], inputs[i]);
+    }
+
+    for (auto t : g->getTensors()) {
+        if (tensorMap.find(t) == tensorMap.end()) {
+            auto tClone = graph->addTensor(t->getDims(), t->getDType());
+            tensorMap.emplace(t, tClone);
+        }
+    }
+
+    for (auto op : g->getOperators()) {
+        TensorVec inputs, outputs;
+        for (auto t : op->getInputs()) {
+            inputs.push_back(tensorMap.at(t));
+        }
+        for (auto t : op->getOutputs()) {
+            outputs.push_back(tensorMap.at(t));
+        }
+        graph->cloneOperator(op, inputs, outputs);
+    }
+
+    TensorVec out;
+    for (auto t : g->getOutputs2Outside()) {
+        out.push_back(tensorMap[t]);
+    }
+    return out;
+}
+
+void SubGraphRewriter::removeSubGraph(MatchGraph match) {
+    TensorVec inputs = match->getInputs();
+
+    for (auto op : match->getOps()) {
+        for (auto tensor : op->getInputs()) {
+            if (std::find(inputs.begin(), inputs.end(), tensor) ==
+                inputs.end()) {
+                graph->removeTensor(tensor);
+            }
+        }
+
+        for (auto tensor : op->getOutputs()) {
+            graph->removeTensor(tensor);
+        }
+        graph->removeOperator(op);
+    }
+}
+
+// inputs and outputs must be specified
+bool SubGraphRewriter::checkReplacement(const SubGraph &pattern,
+                                        const SubGraph &other) const {
+    return checkReplacement(pattern->getInputsFromOutside(),
+                            other->getInputsFromOutside()) &&
+           checkReplacement(pattern->getOutputs2Outside(),
+                            other->getOutputs2Outside()) &&
+           pattern->getInputsFromOutside().size() != 0 &&
+           pattern->getOutputs2Outside().size() != 0;
+}
+
+bool SubGraphRewriter::checkReplacement(const TensorVec &left,
+                                        const TensorVec &right) const {
+    if (left.size() != right.size())
+        return false;
+    for (size_t i = 0; i < left.size(); ++i) {
+        if (!isReplacable(left[i], right[i]))
+            return false;
+    }
+    return true;
+}
+
+bool SubGraphRewriter::isReplacable(const Tensor &l, const Tensor &r) const {
+    return (l->getDType() == r->getDType() && l->getDims() == r->getDims());
+}
+
+} // namespace infini
diff --git a/src/core/operator.cc b/src/core/operator.cc
index 51568f8f..e707d94e 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -37,6 +37,32 @@ bool OperatorObj::isMemBoundOp() const {
            type == OpType::Transpose;
 }

+void OperatorObj::removePredecessors(const Operator &op) {
+    for (auto it = predecessors.begin(); it != predecessors.end();) {
+        if (it->lock() == op)
+            it = predecessors.erase(it);
+        else
+            ++it;
+    }
+}
+
+void OperatorObj::removeSuccessors(const Operator &op) {
+    for (auto it = successors.begin(); it != successors.end();) {
+        if (it->lock() == op)
+            it = successors.erase(it);
+        else
+            ++it;
+    }
+}
+
+void OperatorObj::replaceInput(Tensor t1, Tensor t2) {
+    for (auto itr = inputs.begin(); itr != inputs.end(); ++itr) {
+        if (*itr == t1) {
+            *itr = t2;
+        }
+    }
+}
+
 OpPerfKey OperatorObj::getOpPerfKey() const {
     auto workloadVector = getWorkloadVector();
     // Calculate hash of workload, i.e. hash with shape. This is different from
diff --git a/src/operators/conv.cc b/src/operators/conv.cc
index 413c0122..78f0a891 100644
--- a/src/operators/conv.cc
+++ b/src/operators/conv.cc
@@ -41,7 +41,7 @@ vector<int> ConvBaseObj::getWorkloadVector() const {
 }

 vector<int> ConvBaseObj::getOpAttrVector() const {
-    IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
+    // IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
     return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
 }

diff --git a/src/operators/pooling.cc b/src/operators/pooling.cc
index 4b276db1..c1a2b0d9 100644
--- a/src/operators/pooling.cc
+++ b/src/operators/pooling.cc
@@ -45,7 +45,6 @@ vector<int> PoolingObj::getWorkloadVector() const {
 }

 vector<int> PoolingObj::getOpAttrVector() const {
-    IT_TODO_HALT();
     return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw};
 }

diff --git a/test/core/test_graph_replace.cc b/test/core/test_graph_replace.cc
new file mode 100644
index 00000000..5959bc7d
--- /dev/null
+++ b/test/core/test_graph_replace.cc
@@ -0,0 +1,422 @@
+#include "core/blob.h"
+#include "core/graph_match.h"
+#include "core/runtime.h"
+#include "operators/concat.h"
+#include "operators/conv.h"
+#include "operators/element_wise.h"
+#include "operators/extend.h"
+#include "operators/pad.h"
+#include "operators/pooling.h"
+#include "operators/reduce_mean.h"
+#include "operators/slice.h"
+#include "operators/split.h"
+#include "operators/unary.h"
+#include "test.h"
+namespace infini {
+// hrnet48 head match conv-relu
+TEST(SubGraphRewriter, subGraphMatch1) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph g = make_ref<GraphObj>(runtime);
+    Tensor i0 = g->addTensor({1, 3, 244, 244}, DataType::UInt32);
+    Tensor w0 = g->addTensor({64, 3, 3, 3}, DataType::UInt32);
+    auto conv = g->addOp<ConvObj>(i0, w0, nullptr);
+    auto relu = g->addOp<ReluObj>(conv->getOutput(), nullptr);
+
+    auto w1 = g->addTensor({64, 64, 3, 3}, DataType::UInt32);
+    auto conv1 = g->addOp<ConvObj>(relu->getOutput(0), w1, nullptr);
+    auto relu1 = g->addOp<ReluObj>(conv1->getOutput(), nullptr);
+
+    auto w2 = g->addTensor({64, 64, 1, 1}, DataType::UInt32);
+    auto conv2 = g->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
+    auto relu2 = g->addOp<ReluObj>(conv2->getOutput(), nullptr);
+
+    auto w3 = g->addTensor({256, 64, 1, 1}, DataType::UInt32);
+    auto conv3 = g->addOp<ConvObj>(relu1->getOutput(0), w3, nullptr);
+
+    auto w4 = g->addTensor({64, 64, 3, 3}, DataType::UInt32);
+    auto conv4 = g->addOp<ConvObj>(relu2->getOutput(0), w4, nullptr);
+    auto relu4 = g->addOp<ReluObj>(conv4->getOutput(), nullptr);
+
+    Tensor si0 =
+        make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
+    SubGraph subG = make_ref<SubGraphObj>(runtime, TensorVec{si0});
+    Tensor sw0 = subG->addTensor({64, 64, 3, 3}, DataType::UInt32);
+    auto sconv0 = subG->addOp<ConvObj>(si0, sw0, nullptr);
+    auto srelu0 = subG->addOp<ReluObj>(sconv0->getOutput(), nullptr);
+    subG->setOutputs(srelu0->getOutputs());
+
+    SubGraphRewriter v(g);
+    vector<MatchGraph> subgs = v.findMatch(subG);
+
+    EXPECT_TRUE(subgs.size() == 2);
+}
+
+TEST(MatchGraph, single_input) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    // subG0
+    Tensor si0 =
+        make_ref<TensorObj>(Shape{1, 96, 28, 28}, DataType::UInt32, runtime);
+    SubGraph subG = make_ref<SubGraphObj>(runtime, TensorVec{si0});
+    {
+        auto srelu0 = subG->addOp<ReluObj>(si0, nullptr);
+        auto sw0 = subG->addTensor({96, 96, 3, 3}, DataType::UInt32);
+        auto sconv0 = subG->addOp<ConvObj>(srelu0->getOutput(0), sw0, nullptr);
+        auto srelu1 = subG->addOp<ReluObj>(sconv0->getOutput(), nullptr);
+        auto sw1 = subG->addTensor({96, 96, 3, 3}, DataType::UInt32);
+        auto sconv1 = subG->addOp<ConvObj>(srelu1->getOutput(0), sw1, nullptr);
+        auto sadd0 = subG->addOp<AddObj>(sconv1->getOutput(0),
+                                         srelu0->getOutput(0), nullptr);
+        subG->setOutputs({sadd0->getOutput()});
+    }
+    // subG1
+    Tensor si00 =
+        make_ref<TensorObj>(Shape{1, 48, 56, 56}, DataType::UInt32, runtime);
+    SubGraph subG1 = make_ref<SubGraphObj>(runtime, TensorVec{si00});
+    {
+        auto srelu0 = subG1->addOp<ReluObj>(si00, nullptr);
+        auto sw0 = subG1->addTensor({48, 48, 3, 3}, DataType::UInt32);
+        auto sconv0 =
+            subG1->addOp<ConvObj>(srelu0->getOutput(0), sw0, nullptr);
+        auto srelu1 = subG1->addOp<ReluObj>(sconv0->getOutput(), nullptr);
+        auto sw1 = subG1->addTensor({48, 48, 3, 3}, DataType::UInt32);
+        auto sconv1 = subG1->addOp<ConvObj>(srelu1->getOutput(0), sw1, nullptr);
+        auto sadd0 = subG1->addOp<AddObj>(sconv1->getOutput(0),
+                                          srelu0->getOutput(0), nullptr);
+        subG1->setOutputs({sadd0->getOutput()});
+    }
+
+    // graph
+    Graph g = make_ref<GraphObj>(runtime);
+    SubGraphRewriter v(g);
+
+    Tensor i0 = g->addTensor({1, 256, 56, 56}, DataType::UInt32);
+    auto relu0 = g->addOp<ReluObj>(i0, nullptr);
+
+    Tensor w0 = g->addTensor({96, 256, 3, 3}, DataType::UInt32);
+    auto conv0 =
+        g->addOp<ConvObj>(relu0->getOutput(0), w0, nullptr, 1, 1, 2, 2);
+
+    auto o0 = v.addSubGraph(subG, {conv0->getOutput(0)});
+    auto o1 = v.addSubGraph(subG, o0);
+    auto o2 = v.addSubGraph(subG, o1);
+    auto o3 = v.addSubGraph(subG, o2);
+    auto relu4 = g->addOp<ReluObj>(o3[0], nullptr);
+
+    Tensor w10 = g->addTensor({48, 256, 3, 3}, DataType::UInt32);
+    auto conv10 = g->addOp<ConvObj>(relu0->getOutput(0), w10, nullptr);
+    auto o10 = v.addSubGraph(subG1, {conv10->getOutput(0)});
+    auto o11 = v.addSubGraph(subG1, o10);
+    auto o12 = v.addSubGraph(subG1, o11);
+    auto o13 = v.addSubGraph(subG1, o12);
+    auto relu10 = g->addOp<ReluObj>(o13[0], nullptr);
+    Tensor w1 = g->addTensor({96, 48, 3, 3}, DataType::UInt32);
+    auto conv1 =
+        g->addOp<ConvObj>(relu10->getOutput(), w1, nullptr, 1, 1, 2, 2);
+    auto add1 =
+        g->addOp<AddObj>(relu4->getOutput(), conv1->getOutput(), nullptr);
+
+    auto o4 = v.addSubGraph(subG, TensorVec{add1->getOutput(0)});
+
+    EXPECT_EQ(g->getOperators().size(), 52);
+    vector<MatchGraph> subgs = v.findMatch(subG);
+    EXPECT_TRUE(subgs.size() == 5);
+
+    vector<MatchGraph> subgs1 = v.findMatch(subG1);
+    EXPECT_TRUE(subgs1.size() == 4);
+
+    // test replace
+    Tensor sii0 =
+        make_ref<TensorObj>(Shape{1, 96, 28, 28}, DataType::UInt32, runtime);
+    SubGraph subG2 = make_ref<SubGraphObj>(runtime, TensorVec{sii0});
+    {
+        auto srelu0 = subG2->addOp<ReluObj>(sii0, nullptr);
+        auto sw0 = subG2->addTensor({96, 96, 3, 3}, DataType::UInt32);
+        auto sconv0 = subG2->addOp<ConvObj>(srelu0->getOutput(0), sw0, nullptr);
+        subG2->setOutputs(sconv0->getOutputs());
+    }
+
+    v.replaceSubGraph(subG, subG2);
+    EXPECT_EQ(g->getOperators().size(), 37);
+}
+
+TEST(MatchGraph, multi_input) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    // subG0
+    Tensor i0 =
+        make_ref<TensorObj>(Shape{3, 4, 5, 2}, DataType::UInt32, runtime);
+    Tensor i1 = make_ref<TensorObj>(Shape{3, 4, 5}, DataType::UInt32, runtime);
+    SubGraph subG = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
+    {
+        auto reduce0 =
+            subG->addOp<ReduceMeanObj>(i0, nullptr, vector<int>{3}, false);
+        auto sub0 = subG->addOp<SubObj>(reduce0->getOutput(0), i1, nullptr);
+        subG->setOutputs(sub0->getOutputs());
+    }
+
+    SubGraph replaceG = make_ref<SubGraphObj>(
+        runtime, TensorVec{i0->clone(runtime), i1->clone(runtime)});
+    {
+        auto reduce0 = replaceG->addOp<ReduceMeanObj>(
+            replaceG->getInputsFromOutside()[0], nullptr, vector<int>{3},
+            false);
+        auto sub0 = replaceG->addOp<SubObj>(reduce0->getOutput(0),
+                                            replaceG->getInputsFromOutside()[1],
+                                            nullptr);
+        replaceG->setOutputs(sub0->getOutputs());
+    }
+
+    Graph g = make_ref<GraphObj>(runtime);
+    SubGraphRewriter v(g);
+    {
+        Tensor i0 = g->addTensor({3, 4, 5, 2}, DataType::UInt32);
+        Tensor i1 = g->addTensor({3, 4, 5, 2}, DataType::UInt32);
+        auto add0 = g->addOp<AddObj>(i0, i1, nullptr);
+        auto relu0 = g->addOp<ReluObj>(add0->getOutput(), nullptr);
+        auto reduce0 = g->addOp<ReduceMeanObj>(relu0->getOutput(), nullptr,
+                                               vector<int>{3}, false);
+        auto o0 =
+            v.addSubGraph(subG, {add0->getOutput(), reduce0->getOutput()});
+
+        Tensor i2 = g->addTensor({3, 4, 5}, DataType::UInt32);
+        auto pow0 = g->addOp<PowObj>(o0[0], i2, nullptr);
+
+        Tensor i3 = g->addTensor({3, 4, 5, 2}, DataType::UInt32);
+        auto reduce1 =
+            g->addOp<ReduceMeanObj>(i3, nullptr, vector<int>{3}, false);
+        auto sub0 = g->addOp<SubObj>(reduce1->getOutput(0), pow0->getOutput(0),
+                                     nullptr);
+
+        auto matches = v.findMatch(subG);
+        EXPECT_EQ(2, matches.size());
+
+        auto div0 = g->addOp<DivObj>(reduce1->getOutput(0), i2, nullptr);
+        auto add1 =
+            g->addOp<AddObj>(sub0->getOutput(), div0->getOutput(), nullptr);
+        matches = v.findMatch(subG);
+        EXPECT_EQ(1, matches.size());
+
+        // the two matched subgraphs overlapped, so only one of them was
+        // replaced
+        v.replaceSubGraph(subG, replaceG);
+        EXPECT_EQ(1, v.findMatch(replaceG).size());
+    }
+}
+
+TEST(MatchGraph, multi_output) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    // subg0
+    Tensor i =
+        make_ref<TensorObj>(Shape{1, 192, 71, 71}, DataType::UInt32, runtime);
+    SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{i});
+    {
+        auto maxpool =
+            subg0->addOp<MaxPoolObj>(i, nullptr, 3, 3, 0, 0, 0, 0, 2, 2);
+        Tensor w0 = subg0->addTensor(Shape{64, 192, 1, 1}, DataType::UInt32);
+        auto conv0 = subg0->addOp<ConvObj>(maxpool->getOutput(0), w0, nullptr);
+        auto relu0 = subg0->addOp<ReluObj>(conv0->getOutput(0), nullptr);
+
+        auto pad = subg0->addOp<PadObj>(maxpool->getOutput(0), nullptr,
+                                        vector<int>{0, 0, 1, 1, 0, 0, 1, 1},
+                                        std::nullopt);
+        auto avgpool = subg0->addOp<AvgPoolObj>(pad->getOutput(0), nullptr, 3,
+                                                3, 0, 0, 0, 0, 1, 1);
+        subg0->setOutputs(
+            TensorVec{relu0->getOutput(0), avgpool->getOutput(0)});
+    }
+
+    SubGraph subg1 =
+        make_ref<SubGraphObj>(runtime, TensorVec{i->clone(runtime)});
+    {
+        auto avgpool = subg1->addOp<AvgPoolObj>(
+            subg1->getInputsFromOutside()[0], nullptr, 3, 3, 0, 0, 0, 0, 2, 2);
+
+        auto relu0 = subg1->addOp<ReluObj>(avgpool->getOutput(0), nullptr);
+
+        auto split0 =
+            subg1->addOp<SplitObj>(avgpool->getOutput(0), std::nullopt, 1, 3);
+        subg1->setOutputs(TensorVec{split0->getOutput(1), relu0->getOutput(0)});
+    }
+
+    Graph g = make_ref<GraphObj>(runtime);
+    SubGraphRewriter v(g);
+    {
+        auto input = g->cloneTensor(i);
+        auto outs = v.addSubGraph(subg0, {input});
+        EXPECT_EQ(2, outs.size());
+        Tensor w0 = g->addTensor(Shape{96, 64, 3, 3}, DataType::UInt32);
+        auto conv0 = g->addOp<ConvObj>(outs[0], w0, nullptr, 1, 1);
+        auto relu0 = g->addOp<ReluObj>(conv0->getOutput(0), nullptr);
+
+        Tensor w1 = g->addTensor(Shape{96, 96, 3, 3}, DataType::UInt32);
+        auto conv1 = g->addOp<ConvObj>(relu0->getOutput(), w1, nullptr, 1, 1);
+        auto relu1 = g->addOp<ReluObj>(conv1->getOutput(0), nullptr);
+
+        Tensor w2 = g->addTensor(Shape{32, 192, 1, 1}, DataType::UInt32);
+        auto conv2 = g->addOp<ConvObj>(outs[1], w2, nullptr);
+        auto relu2 = g->addOp<ReluObj>(conv2->getOutput(0), nullptr);
+
+        Tensor i0 = g->addTensor(Shape{1, 64, 35, 35}, DataType::UInt32);
+        Tensor i1 = g->addTensor(Shape{1, 64, 35, 35}, DataType::UInt32);
+        auto concat = g->addOp<ConcatObj>(
+            TensorVec{i0, i1, relu1->getOutput(), relu2->getOutput()}, nullptr,
+            1);
+        auto o = concat->getOutput();
+        EXPECT_TRUE((o->getDims() == Shape{1, 256, 35, 35}));
+    }
+
+    auto matches = v.findMatch(subg0);
+    EXPECT_EQ(1, matches.size());
+
+    v.replaceSubGraph(subg0, subg1);
+    auto matches2 = v.findMatch(subg1);
+    EXPECT_EQ(1, matches2.size());
+}
+
+// gcn
+TEST(MatchGraph, multi_input_output) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    // subg0
+    Tensor i0 =
+        make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
+    Tensor i1 =
+        make_ref<TensorObj>(Shape{1, 64, 56, 56}, DataType::UInt32, runtime);
+    SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
+    {
+        auto slice = subg0->addOp<SliceObj>(i0, nullptr, vector<int>{0, 0},
+                                            vector<int>{55, 55},
+                                            vector<int>{2, 3}, std::nullopt);
+        auto relu0 = subg0->addOp<ReluObj>(slice->getOutput(0), nullptr);
+        Tensor w0 = subg0->addTensor(Shape{256, 64, 1, 1}, DataType::UInt32);
+        auto conv0 = subg0->addOp<ConvObj>(relu0->getOutput(0), w0, nullptr);
+
+        auto conv1 = subg0->addOp<ConvObj>(i1, w0, nullptr);
+        auto add = subg0->addOp<AddObj>(conv0->getOutput(0),
+                                        conv1->getOutput(0), nullptr);
+
+        auto relu1 = subg0->addOp<ReluObj>(add->getOutput(0), nullptr);
+        Tensor w2 = subg0->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
+        auto conv2 = subg0->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
+        auto maxpool = subg0->addOp<MaxPoolObj>(relu1->getOutput(0), nullptr, 3,
+                                                3, 0, 0, 0, 0, 2, 2);
+        subg0->setOutputs(
+            TensorVec{conv2->getOutput(0), maxpool->getOutput(0)});
+    }
+
+    SubGraph subg1 = make_ref<SubGraphObj>(runtime, TensorVec{i1, i0});
+    {
+        auto slice = subg1->addOp<SliceObj>(i0, nullptr, vector<int>{0, 0},
+                                            vector<int>{55, 55},
+                                            vector<int>{2, 3}, std::nullopt);
+        auto relu0 = subg1->addOp<ReluObj>(slice->getOutput(0), nullptr);
+        Tensor w0 = subg1->addTensor(Shape{256, 64, 1, 1}, DataType::UInt32);
+        auto conv0 = subg1->addOp<ConvObj>(relu0->getOutput(0), w0, nullptr);
+
+        auto conv1 = subg1->addOp<ConvObj>(i1, w0, nullptr);
+        auto add = subg1->addOp<AddObj>(conv1->getOutput(0),
+                                        conv0->getOutput(0), nullptr);
+
+        auto relu1 = subg1->addOp<ReluObj>(add->getOutput(0), nullptr);
+        Tensor w2 = subg1->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
+        auto conv2 = subg1->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
+        auto maxpool = subg1->addOp<MaxPoolObj>(relu1->getOutput(0), nullptr, 3,
+                                                3, 0, 0, 0, 0, 2, 2);
+        subg1->setOutputs(
+            TensorVec{maxpool->getOutput(0), conv2->getOutput(0)});
+    }
+
+    SubGraph subg2 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
+    {
+        auto extend = subg2->addOp<ExtendObj>(i0, nullptr, 1, 3);
+
+        auto slice = subg2->addOp<SliceObj>(
+            extend->getOutput(0), nullptr, vector<int>{0, 0},
+            vector<int>{55, 55}, vector<int>{2, 3}, std::nullopt);
+
+        auto extend1 = subg2->addOp<ExtendObj>(i1, nullptr, 1, 3);
+        auto add = subg2->addOp<AddObj>(extend1->getOutput(0),
+                                        slice->getOutput(0), nullptr);
+
+        auto relu1 = subg2->addOp<ReluObj>(add->getOutput(0), nullptr);
+        Tensor w2 = subg2->addTensor(Shape{128, 256, 1, 1}, DataType::UInt32);
+        auto conv2 = subg2->addOp<ConvObj>(relu1->getOutput(0), w2, nullptr);
+        auto avgpool = subg2->addOp<AvgPoolObj>(relu1->getOutput(0), nullptr, 3,
+                                                3, 0, 0, 0, 0, 2, 2);
+        subg2->setOutputs(
+            TensorVec{conv2->getOutput(0), avgpool->getOutput(0)});
+    }
+
+    Graph g = make_ref<GraphObj>(runtime);
+    SubGraphRewriter v(g);
+    {
+        auto i = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
+        auto relu = g->addOp<ReluObj>(i, nullptr);
+        auto maxPool = g->addOp<MaxPoolObj>(relu->getOutput(0), nullptr, 3, 3,
+                                            0, 0, 1, 1, 2, 2);
+        auto out0 =
+            v.addSubGraph(subg0, {relu->getOutput(0), maxPool->getOutput(0)});
+        auto out1 =
+            v.addSubGraph(subg1, {maxPool->getOutput(0), relu->getOutput(0)});
+        EXPECT_EQ(2, out0.size());
+        EXPECT_EQ(2, out1.size());
+        auto div = g->addOp<DivObj>(out0[0], out1[1], nullptr);
+        auto sub = g->addOp<SubObj>(out0[1], out1[0], nullptr);
+    }
+
+    EXPECT_EQ(2, v.findMatch(subg0).size());
+    EXPECT_EQ(2, v.findMatch(subg1).size());
+    v.replaceSubGraph(subg0, subg2);
+    EXPECT_EQ(v.findMatch(subg2).size(), 2);
+}
+
+/* One node having two or more successors is not supported yet.
+TEST(MatchGraph, same_successor) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    // subg0
+    Tensor i0 =
+        make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
+    Tensor i1 =
+        make_ref<TensorObj>(Shape{1, 64, 112, 112}, DataType::UInt32, runtime);
+    SubGraph subg0 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
+    {
+        auto add0 = subg0->addOp<AddObj>(i0, i1, nullptr);
+        auto add1 = subg0->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
+        auto add2 = subg0->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
+
+        auto mul = subg0->addOp<MulObj>(add1->getOutput(0), i1, nullptr);
+        auto div = subg0->addOp<DivObj>(add2->getOutput(0), i1, nullptr);
+
+        auto sub =
+            subg0->addOp<SubObj>(mul->getOutput(0), div->getOutput(0), nullptr);
+
+        subg0->setOutputs(TensorVec{sub->getOutput(0)});
+    }
+
+    // pattern
+    SubGraph pattern1 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
+    {
+        auto add0 = pattern1->addOp<AddObj>(i0, i1, nullptr);
+        auto add1 = pattern1->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
+        auto div = pattern1->addOp<DivObj>(add1->getOutput(0), i1, nullptr);
+        pattern1->setOutputs(TensorVec{add0->getOutput(0), div->getOutput(0)});
+    }
+
+    // pattern
+    SubGraph pattern2 = make_ref<SubGraphObj>(runtime, TensorVec{i0, i1});
+    {
+        auto add0 = pattern2->addOp<AddObj>(i0, i1, nullptr);
+        auto add1 = pattern2->addOp<AddObj>(add0->getOutput(0), i1, nullptr);
+        pattern2->setOutputs(TensorVec{add0->getOutput(0), add1->getOutput(0)});
+    }
+
+    Graph g = make_ref<GraphObj>(runtime);
+    SubGraphRewriter v(g);
+    {
+        i0 = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
+        i1 = g->addTensor(Shape{1, 64, 112, 112}, DataType::UInt32);
+        auto out0 = v.addSubGraph(subg0, {i0, i1});
+    }
+
+    EXPECT_EQ(1, v.findMatch(pattern1).size());
+    EXPECT_EQ(2, v.findMatch(pattern2).size());
+    v.replaceSubGraph(pattern2, pattern1);
+    EXPECT_EQ(v.findMatch(pattern2).size(), 2);
+}*/
+} // namespace infini
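
For reference, a minimal end-to-end sketch of the new rewriter API, put together
the same way the tests above use it. The helper name foldConvRelu and all shapes
here are illustrative assumptions, not part of the patch; SubGraphObj,
SubGraphRewriter::findMatch, and replaceSubGraph are the interfaces added in
include/core/graph_match.h.

    #include "core/graph_match.h"
    #include "operators/conv.h"
    #include "operators/unary.h"

    using namespace infini;

    // Fold every Conv->Relu pair in `g` into a bare Conv. The pattern and the
    // replacement must expose the same number of outside inputs/outputs with
    // identical dtypes and shapes, which checkReplacement() verifies.
    void foldConvRelu(Graph g, Runtime runtime) {
        // Pattern: one outside input -> Conv -> Relu -> one outside output.
        Tensor pIn = make_ref<TensorObj>(Shape{1, 64, 112, 112},
                                         DataType::UInt32, runtime);
        SubGraph pattern = make_ref<SubGraphObj>(runtime, TensorVec{pIn});
        Tensor pW = pattern->addTensor({64, 64, 3, 3}, DataType::UInt32);
        auto pConv = pattern->addOp<ConvObj>(pIn, pW, nullptr);
        auto pRelu = pattern->addOp<ReluObj>(pConv->getOutput(), nullptr);
        pattern->setOutputs(pRelu->getOutputs());

        // Replacement: the same Conv without the Relu. Its output tensor has
        // the same shape as the pattern's output, so isReplacable() holds.
        Tensor rIn = make_ref<TensorObj>(Shape{1, 64, 112, 112},
                                         DataType::UInt32, runtime);
        SubGraph repl = make_ref<SubGraphObj>(runtime, TensorVec{rIn});
        Tensor rW = repl->addTensor({64, 64, 3, 3}, DataType::UInt32);
        auto rConv = repl->addOp<ConvObj>(rIn, rW, nullptr);
        repl->setOutputs(rConv->getOutputs());

        SubGraphRewriter rewriter(g);
        // findMatch() is optional here; replaceSubGraph() re-runs the matching
        // internally and folds every non-overlapping occurrence.
        if (!rewriter.findMatch(pattern).empty())
            rewriter.replaceSubGraph(pattern, repl);
    }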