forked from jiuyuan/InfiniTensor
fix reshape
This commit is contained in:
parent
1c84c3ebb9
commit
4fbeb2b233
|
@ -53,9 +53,9 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
|
|||
// 1: processing Conv-related subgraph
|
||||
// 2: after processing Conv-related subgraph
|
||||
int modelStatus = 0;
|
||||
for (const auto &t : inG->getTensors())
|
||||
if (t->getDims().size() != 4)
|
||||
return nullptr;
|
||||
// for (const auto &t : inG->getTensors())
|
||||
// if (t->getDims().size() != 4)
|
||||
// return nullptr;
|
||||
auto getTensor = [&g, &tensors](const Tensor &inTensor) {
|
||||
auto uid = inTensor->getGuid();
|
||||
if (auto it = tensors.find(uid); it == tensors.end()) {
|
||||
|
@ -98,6 +98,14 @@ Graph convertNCHWtoNHWCModel(Graph inG) {
|
|||
g->addOpWithOutputs<ConvTransposed2dNHWCObj>(
|
||||
inputs[0], weight, outputs[0], ph, pw, sh, sw, dh, dw, oph,
|
||||
opw, group, bias, cOp->getAct());
|
||||
} else if (const auto &rOp = as<ReshapeObj>(op)) {
|
||||
auto dims = rOp->getShape();
|
||||
IT_ASSERT(dims.size() == 4);
|
||||
auto b = dims[1];
|
||||
auto c = dims[2];
|
||||
auto d = dims[3];
|
||||
dims[1] = d; dims[2] = b; dims[3] = c;
|
||||
g->addOpWithOutputs<ReshapeObj>(inputs[0], outputs[0], dims);
|
||||
} else if (const auto &pOp = as<PoolingObj>(op)) {
|
||||
auto t = g->addOp<TransposeObj>(inputs[0], nullptr,
|
||||
vector<int>{0, 2, 3, 1})
|
||||
|
|
Loading…
Reference in New Issue