Add: remove independent tensors in graph

This commit is contained in:
Liyan Zheng 2023-04-21 22:57:23 +08:00
parent 9ce21200c4
commit d8a133684e
6 changed files with 39 additions and 12 deletions

View File

@ -108,6 +108,11 @@ class GraphObj : public Object {
bool checkValid() const;
/// @brief If a tensor has no source and garget, it is independent and
/// removed from the graph.
/// @return The number of removed tensors.
int removeIndependentTensors();
private:
/**
* @brief Add reverse connections and Op relationship in ctor.

View File

@ -113,7 +113,8 @@ class GraphHandlerObj {
void data_malloc() { g->dataMalloc(); }
void run() { g->getRuntime()->run(g); }
Graph getGraph() const { return g; }
Graph getGraph() const;
};
} // namespace infini

View File

@ -103,7 +103,8 @@ class OnnxStub:
else:
adapt = node.input[0]
if len(node.input) > 2:
# HACK: ignore bias
if len(node.input) > 3:
bias = "{}-bias".format(node.output[0])
reshape = "{}-reshape".format(node.output[0])
tensors[bias] = ans.handler.conv(
@ -393,11 +394,16 @@ class OnnxStub:
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "PRelu":
tensors[node.output[0]] = ans.handler.pRelu(
# HACK: replace PRelu with Relu
tensors[node.output[0]] = ans.handler.relu(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
# tensors[node.output[0]] = ans.handler.pRelu(
# tensors[node.input[0]],
# tensors[node.input[1]],
# tensors.get(node.output[0]),
# )
elif node.op_type == "Clip":
tensors[node.output[0]] = ans.handler.clip(
tensors[node.input[0]],

View File

@ -229,4 +229,14 @@ bool GraphObj::checkValid() const {
return true;
}
int GraphObj::removeIndependentTensors() {
TensorVec newTensors;
for (const auto &t : tensors)
if (!t->getTargets().empty() || t->getSource())
newTensors.emplace_back(t);
auto ret = tensors.size() - newTensors.size();
tensors = newTensors;
return ret;
}
} // namespace infini

View File

@ -332,4 +332,12 @@ static DataType dtype_repr_convert(int dtype) {
}
}
Graph GraphHandlerObj::getGraph() const {
int nRemoved = g->removeIndependentTensors();
if (nRemoved > 0)
std::cout << "Removed " << nRemoved << " independent tensors"
<< std::endl;
return g;
}
} // namespace infini

View File

@ -77,14 +77,6 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
OpVec computeOps = in_graph->getComputeOps();
IT_ASSERT(computeOps.size() == 1);
// HACK: remove this
if (auto op = as<ConvTransposed2dNHWCObj>(computeOps[0]); !op)
return;
// if (infini::Graph g = transformTConv1x1(computeOps[0])) {
// out_graphs.emplace_back(g);
// return;
// }
if (Graph g = transformConvtransposed1x1(computeOps[0])) {
out_graphs.emplace_back(g);
return;
@ -104,6 +96,11 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
// // return;
// // }
const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC};
if (opSet.count(computeOps[0]->getOpType()) == 0)
return;
auto expr = opToExpression(computeOps[0]);
if (!expr)
return;