forked from jiuyuan/InfiniTensor
Add: remove independent tensors in graph
This commit is contained in:
parent
9ce21200c4
commit
d8a133684e
|
@ -108,6 +108,11 @@ class GraphObj : public Object {
|
|||
|
||||
bool checkValid() const;
|
||||
|
||||
/// @brief If a tensor has no source and no target, it is independent and
|
||||
/// removed from the graph.
|
||||
/// @return The number of removed tensors.
|
||||
int removeIndependentTensors();
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Add reverse connections and Op relationship in ctor.
|
||||
|
|
|
@ -113,7 +113,8 @@ class GraphHandlerObj {
|
|||
void data_malloc() { g->dataMalloc(); }
|
||||
|
||||
void run() { g->getRuntime()->run(g); }
|
||||
Graph getGraph() const { return g; }
|
||||
|
||||
Graph getGraph() const;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -103,7 +103,8 @@ class OnnxStub:
|
|||
else:
|
||||
adapt = node.input[0]
|
||||
|
||||
if len(node.input) > 2:
|
||||
# HACK: ignore bias
|
||||
if len(node.input) > 3:
|
||||
bias = "{}-bias".format(node.output[0])
|
||||
reshape = "{}-reshape".format(node.output[0])
|
||||
tensors[bias] = ans.handler.conv(
|
||||
|
@ -393,11 +394,16 @@ class OnnxStub:
|
|||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "PRelu":
|
||||
tensors[node.output[0]] = ans.handler.pRelu(
|
||||
# HACK: replace PRelu with Relu
|
||||
tensors[node.output[0]] = ans.handler.relu(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
# tensors[node.output[0]] = ans.handler.pRelu(
|
||||
# tensors[node.input[0]],
|
||||
# tensors[node.input[1]],
|
||||
# tensors.get(node.output[0]),
|
||||
# )
|
||||
elif node.op_type == "Clip":
|
||||
tensors[node.output[0]] = ans.handler.clip(
|
||||
tensors[node.input[0]],
|
||||
|
|
|
@ -229,4 +229,14 @@ bool GraphObj::checkValid() const {
|
|||
return true;
|
||||
}
|
||||
|
||||
int GraphObj::removeIndependentTensors() {
|
||||
TensorVec newTensors;
|
||||
for (const auto &t : tensors)
|
||||
if (!t->getTargets().empty() || t->getSource())
|
||||
newTensors.emplace_back(t);
|
||||
auto ret = tensors.size() - newTensors.size();
|
||||
tensors = newTensors;
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -332,4 +332,12 @@ static DataType dtype_repr_convert(int dtype) {
|
|||
}
|
||||
}
|
||||
|
||||
/// @brief Return the held graph, first pruning tensors that are attached
/// to no operator (see GraphObj::removeIndependentTensors), and report
/// how many were pruned.
/// @return The (possibly pruned) graph held by this handler.
Graph GraphHandlerObj::getGraph() const {
    const int nRemoved = g->removeIndependentTensors();
    if (nRemoved > 0)
        // '\n' instead of std::endl: no need to force a stream flush here.
        std::cout << "Removed " << nRemoved << " independent tensors\n";
    return g;
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -77,14 +77,6 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
|
|||
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
|
||||
OpVec computeOps = in_graph->getComputeOps();
|
||||
IT_ASSERT(computeOps.size() == 1);
|
||||
// HACK: remove this
|
||||
if (auto op = as<ConvTransposed2dNHWCObj>(computeOps[0]); !op)
|
||||
return;
|
||||
|
||||
// if (infini::Graph g = transformTConv1x1(computeOps[0])) {
|
||||
// out_graphs.emplace_back(g);
|
||||
// return;
|
||||
// }
|
||||
if (Graph g = transformConvtransposed1x1(computeOps[0])) {
|
||||
out_graphs.emplace_back(g);
|
||||
return;
|
||||
|
@ -104,6 +96,11 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
|
|||
// // return;
|
||||
// // }
|
||||
|
||||
const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC};
|
||||
if (opSet.count(computeOps[0]->getOpType()) == 0)
|
||||
return;
|
||||
|
||||
|
||||
auto expr = opToExpression(computeOps[0]);
|
||||
if (!expr)
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue