opt: 优化 SliceObj 构造器实现

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 16:44:08 +08:00
parent 341cf1f943
commit f9d0076a86
2 changed files with 18 additions and 19 deletions

View File

@ -20,14 +20,14 @@ class SliceObj : public OperatorObj {
* list which has the same length with axis.
* @param ends The end position to slice at certain axes. `ends` is a list
* which has the same length with axis.
* @param axis The dimensions to slice. If `axis` is empty, it is set to [0,
* @param axes The dimensions to slice. If `axis` is empty, it is set to [0,
* 1, ..., d-1], where d is the number of dimensions of the input tensor.
* @param steps The step to slice at certain axes. `step` is a list which
* has the same length with axis.
*/
SliceObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &starts, const vector<int> &ends,
const optional<vector<int>> &axis,
const optional<vector<int>> &axes,
const optional<vector<int>> &steps);
OP_CLONE(SliceObj);

View File

@ -3,34 +3,33 @@
namespace infini {
SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &starts, const vector<int> &ends,
const optional<vector<int>> &axis,
const optional<vector<int>> &axes,
const optional<vector<int>> &steps)
: OperatorObj(OpType::Slice, {input}, {output}) {
if (steps != std::nullopt)
if (steps)
IT_TODO_HALT();
IT_ASSERT(starts.size() == ends.size());
if (axis == std::nullopt) {
if (!axes) {
this->starts = starts;
this->ends = ends;
} else {
int nAxis = (*axis).size();
IT_ASSERT((int)starts.size() == nAxis);
auto nAxis = (*axes).size();
IT_ASSERT(starts.size() == nAxis);
int nDims = input->getDims().size();
vector<int> tmpS(nDims, 0), tmpE;
for (int i = 0; i < nDims; ++i) {
tmpE.emplace_back(input->getDims()[i] - 1);
}
auto dims = input->getDims();
this->starts = vector<int>(dims.size(), 0);
this->ends.resize(dims.size());
std::transform(dims.begin(), dims.end(), this->ends.begin(),
[](auto x) { return x - 1; });
for (int i = 0; i < nAxis; ++i) {
if ((*axis)[i] < 0)
for (size_t j = 0; j < nAxis; ++j) {
auto i = (*axes)[j];
if (i < 0)
IT_TODO_HALT();
tmpS[(*axis)[i]] = starts[i];
tmpE[(*axis)[i]] = ends[i];
this->starts[i] = starts[j];
this->ends[i] = ends[j];
}
this->starts = tmpS;
this->ends = tmpE;
}
IT_ASSERT(checkValid(graph));
}