forked from jiuyuan/InfiniTensor
opt: 优化 PadObj 和 SplitObj 构造器实现
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
bb0e7540cc
commit
7893ae0cca
|
@ -5,19 +5,19 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
const vector<int> &_pads,
|
||||
const optional<const vector<int>> &axis)
|
||||
: OperatorObj(OpType::Pad, {input}, {output}) {
|
||||
if (axis == std::nullopt)
|
||||
if (!axis)
|
||||
pads = _pads;
|
||||
else {
|
||||
int nAxis = (*axis).size();
|
||||
IT_ASSERT((int)_pads.size() == nAxis * 2);
|
||||
int nDims = input->getDims().size();
|
||||
vector<int> tmp(nDims * 2, 0);
|
||||
auto nAxis = (*axis).size();
|
||||
IT_ASSERT(_pads.size() == nAxis * 2);
|
||||
auto nDims = input->getDims().size();
|
||||
pads = vector<int>(nDims * 2, 0);
|
||||
|
||||
for (int i = 0; i < nAxis; ++i) {
|
||||
tmp[(*axis)[i]] = _pads[i];
|
||||
tmp[(*axis)[i] + nDims] = _pads[i + nAxis];
|
||||
for (size_t i = 0; i < nAxis; ++i) {
|
||||
auto j = (*axis)[i];
|
||||
pads[j] = _pads[i];
|
||||
pads[j + nDims] = _pads[i + nAxis];
|
||||
}
|
||||
pads = tmp;
|
||||
}
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ namespace infini {
|
|||
SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
||||
std::optional<TensorVec> outputs, int dim, int num)
|
||||
: OperatorObj(OpType::Split, {input},
|
||||
((!outputs) ? TensorVec{nullptr} : (*outputs))),
|
||||
((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
|
||||
dim(dim), num(num), ratio({}) {
|
||||
int dimSize = input->getDims().at(dim);
|
||||
int pieceSize = dimSize / num;
|
||||
|
@ -17,10 +17,6 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
|||
} else
|
||||
ratio = std::vector<int>(num, pieceSize);
|
||||
|
||||
if (!outputs) {
|
||||
TensorVec tmp(num, nullptr);
|
||||
this->outputs = tmp;
|
||||
}
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue