fix reshape

This commit is contained in:
wanghailu 2023-06-26 11:33:51 +08:00
parent 1c84c3ebb9
commit 4fbeb2b233
1 changed files with 11 additions and 3 deletions

View File

@ -53,9 +53,9 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
// 1: processing Conv-related subgraph
// 2: after processing Conv-related subgraph
int modelStatus = 0;
for (const auto &t : inG->getTensors())
if (t->getDims().size() != 4)
return nullptr;
// for (const auto &t : inG->getTensors())
// if (t->getDims().size() != 4)
// return nullptr;
auto getTensor = [&g, &tensors](const Tensor &inTensor) {
auto uid = inTensor->getGuid();
if (auto it = tensors.find(uid); it == tensors.end()) {
@ -98,6 +98,14 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
g->addOpWithOutputs<ConvTransposed2dNHWCObj>(
inputs[0], weight, outputs[0], ph, pw, sh, sw, dh, dw, oph,
opw, group, bias, cOp->getAct());
} else if (const auto &rOp = as<ReshapeObj>(op)) {
auto dims = rOp->getShape();
IT_ASSERT(dims.size() == 4);
auto b = dims[1];
auto c = dims[2];
auto d = dims[3];
dims[1] = d; dims[2] = b; dims[3] = c;
g->addOpWithOutputs<ReshapeObj>(inputs[0], outputs[0], dims);
} else if (const auto &pOp = as<PoolingObj>(op)) {
auto t = g->addOp<TransposeObj>(inputs[0], nullptr,
vector<int>{0, 2, 3, 1})