opt: 优化 PadObj 和 SplitObj 构造器实现

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-15 11:28:49 +08:00
parent bb0e7540cc
commit 7893ae0cca
2 changed files with 10 additions and 14 deletions

View File

@ -5,19 +5,19 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &_pads, const vector<int> &_pads,
const optional<const vector<int>> &axis) const optional<const vector<int>> &axis)
: OperatorObj(OpType::Pad, {input}, {output}) { : OperatorObj(OpType::Pad, {input}, {output}) {
if (axis == std::nullopt) if (!axis)
pads = _pads; pads = _pads;
else { else {
int nAxis = (*axis).size(); auto nAxis = (*axis).size();
IT_ASSERT((int)_pads.size() == nAxis * 2); IT_ASSERT(_pads.size() == nAxis * 2);
int nDims = input->getDims().size(); auto nDims = input->getDims().size();
vector<int> tmp(nDims * 2, 0); pads = vector<int>(nDims * 2, 0);
for (int i = 0; i < nAxis; ++i) { for (size_t i = 0; i < nAxis; ++i) {
tmp[(*axis)[i]] = _pads[i]; auto j = (*axis)[i];
tmp[(*axis)[i] + nDims] = _pads[i + nAxis]; pads[j] = _pads[i];
pads[j + nDims] = _pads[i + nAxis];
} }
pads = tmp;
} }
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }

View File

@ -5,7 +5,7 @@ namespace infini {
SplitObj::SplitObj(GraphObj *graph, Tensor input, SplitObj::SplitObj(GraphObj *graph, Tensor input,
std::optional<TensorVec> outputs, int dim, int num) std::optional<TensorVec> outputs, int dim, int num)
: OperatorObj(OpType::Split, {input}, : OperatorObj(OpType::Split, {input},
((!outputs) ? TensorVec{nullptr} : (*outputs))), ((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
dim(dim), num(num), ratio({}) { dim(dim), num(num), ratio({}) {
int dimSize = input->getDims().at(dim); int dimSize = input->getDims().at(dim);
int pieceSize = dimSize / num; int pieceSize = dimSize / num;
@ -17,10 +17,6 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
} else } else
ratio = std::vector<int>(num, pieceSize); ratio = std::vector<int>(num, pieceSize);
if (!outputs) {
TensorVec tmp(num, nullptr);
this->outputs = tmp;
}
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }