From 7893ae0ccaf34bdd160ca33308ff3624eab47e9e Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 15 Feb 2023 11:28:49 +0800 Subject: [PATCH] =?UTF-8?q?opt:=20=E4=BC=98=E5=8C=96=20PadObj=20=E5=92=8C?= =?UTF-8?q?=20SplitObj=20=E6=9E=84=E9=80=A0=E5=99=A8=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/operators/pad.cc | 18 +++++++++--------- src/operators/split.cc | 6 +----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/operators/pad.cc b/src/operators/pad.cc index f3e219d6..7e914f8e 100644 --- a/src/operators/pad.cc +++ b/src/operators/pad.cc @@ -5,19 +5,19 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output, const vector &_pads, const optional> &axis) : OperatorObj(OpType::Pad, {input}, {output}) { - if (axis == std::nullopt) + if (!axis) pads = _pads; else { - int nAxis = (*axis).size(); - IT_ASSERT((int)_pads.size() == nAxis * 2); - int nDims = input->getDims().size(); - vector tmp(nDims * 2, 0); + auto nAxis = (*axis).size(); + IT_ASSERT(_pads.size() == nAxis * 2); + auto nDims = input->getDims().size(); + pads = vector(nDims * 2, 0); - for (int i = 0; i < nAxis; ++i) { - tmp[(*axis)[i]] = _pads[i]; - tmp[(*axis)[i] + nDims] = _pads[i + nAxis]; + for (size_t i = 0; i < nAxis; ++i) { + auto j = (*axis)[i]; + pads[j] = _pads[i]; + pads[j + nDims] = _pads[i + nAxis]; } - pads = tmp; } IT_ASSERT(checkValid(graph)); } diff --git a/src/operators/split.cc b/src/operators/split.cc index 38c0ffbc..eb602417 100644 --- a/src/operators/split.cc +++ b/src/operators/split.cc @@ -5,7 +5,7 @@ namespace infini { SplitObj::SplitObj(GraphObj *graph, Tensor input, std::optional outputs, int dim, int num) : OperatorObj(OpType::Split, {input}, - ((!outputs) ? TensorVec{nullptr} : (*outputs))), + ((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))), dim(dim), num(num), ratio({}) { int dimSize = input->getDims().at(dim); int pieceSize = dimSize / num; @@ -17,10 +17,6 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input, } else ratio = std::vector(num, pieceSize); - if (!outputs) { - TensorVec tmp(num, nullptr); - this->outputs = tmp; - } IT_ASSERT(checkValid(graph)); }