opt: 优化 PadObj 和 SplitObj 构造器实现

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-15 11:28:49 +08:00
parent bb0e7540cc
commit 7893ae0cca
2 changed files with 10 additions and 14 deletions

View File

@ -5,19 +5,19 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &_pads,
const optional<const vector<int>> &axis)
: OperatorObj(OpType::Pad, {input}, {output}) {
if (axis == std::nullopt)
if (!axis)
pads = _pads;
else {
int nAxis = (*axis).size();
IT_ASSERT((int)_pads.size() == nAxis * 2);
int nDims = input->getDims().size();
vector<int> tmp(nDims * 2, 0);
auto nAxis = (*axis).size();
IT_ASSERT(_pads.size() == nAxis * 2);
auto nDims = input->getDims().size();
pads = vector<int>(nDims * 2, 0);
for (int i = 0; i < nAxis; ++i) {
tmp[(*axis)[i]] = _pads[i];
tmp[(*axis)[i] + nDims] = _pads[i + nAxis];
for (size_t i = 0; i < nAxis; ++i) {
auto j = (*axis)[i];
pads[j] = _pads[i];
pads[j + nDims] = _pads[i + nAxis];
}
pads = tmp;
}
IT_ASSERT(checkValid(graph));
}

View File

@ -5,7 +5,7 @@ namespace infini {
SplitObj::SplitObj(GraphObj *graph, Tensor input,
std::optional<TensorVec> outputs, int dim, int num)
: OperatorObj(OpType::Split, {input},
((!outputs) ? TensorVec{nullptr} : (*outputs))),
((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
dim(dim), num(num), ratio({}) {
int dimSize = input->getDims().at(dim);
int pieceSize = dimSize / num;
@ -17,10 +17,6 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
} else
ratio = std::vector<int>(num, pieceSize);
if (!outputs) {
TensorVec tmp(num, nullptr);
this->outputs = tmp;
}
IT_ASSERT(checkValid(graph));
}