forked from jiuyuan/InfiniTensor
opt: 优化 PadObj 和 SplitObj 构造器实现
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
bb0e7540cc
commit
7893ae0cca
|
@ -5,19 +5,19 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const vector<int> &_pads,
|
const vector<int> &_pads,
|
||||||
const optional<const vector<int>> &axis)
|
const optional<const vector<int>> &axis)
|
||||||
: OperatorObj(OpType::Pad, {input}, {output}) {
|
: OperatorObj(OpType::Pad, {input}, {output}) {
|
||||||
if (axis == std::nullopt)
|
if (!axis)
|
||||||
pads = _pads;
|
pads = _pads;
|
||||||
else {
|
else {
|
||||||
int nAxis = (*axis).size();
|
auto nAxis = (*axis).size();
|
||||||
IT_ASSERT((int)_pads.size() == nAxis * 2);
|
IT_ASSERT(_pads.size() == nAxis * 2);
|
||||||
int nDims = input->getDims().size();
|
auto nDims = input->getDims().size();
|
||||||
vector<int> tmp(nDims * 2, 0);
|
pads = vector<int>(nDims * 2, 0);
|
||||||
|
|
||||||
for (int i = 0; i < nAxis; ++i) {
|
for (size_t i = 0; i < nAxis; ++i) {
|
||||||
tmp[(*axis)[i]] = _pads[i];
|
auto j = (*axis)[i];
|
||||||
tmp[(*axis)[i] + nDims] = _pads[i + nAxis];
|
pads[j] = _pads[i];
|
||||||
|
pads[j + nDims] = _pads[i + nAxis];
|
||||||
}
|
}
|
||||||
pads = tmp;
|
|
||||||
}
|
}
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ namespace infini {
|
||||||
SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
||||||
std::optional<TensorVec> outputs, int dim, int num)
|
std::optional<TensorVec> outputs, int dim, int num)
|
||||||
: OperatorObj(OpType::Split, {input},
|
: OperatorObj(OpType::Split, {input},
|
||||||
((!outputs) ? TensorVec{nullptr} : (*outputs))),
|
((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
|
||||||
dim(dim), num(num), ratio({}) {
|
dim(dim), num(num), ratio({}) {
|
||||||
int dimSize = input->getDims().at(dim);
|
int dimSize = input->getDims().at(dim);
|
||||||
int pieceSize = dimSize / num;
|
int pieceSize = dimSize / num;
|
||||||
|
@ -17,10 +17,6 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
||||||
} else
|
} else
|
||||||
ratio = std::vector<int>(num, pieceSize);
|
ratio = std::vector<int>(num, pieceSize);
|
||||||
|
|
||||||
if (!outputs) {
|
|
||||||
TensorVec tmp(num, nullptr);
|
|
||||||
this->outputs = tmp;
|
|
||||||
}
|
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue