diff --git a/include/operators/slice.h b/include/operators/slice.h index e7262552..7aeb0941 100644 --- a/include/operators/slice.h +++ b/include/operators/slice.h @@ -20,14 +20,14 @@ class SliceObj : public OperatorObj { * list which has the same length with axis. * @param ends The end position to slice at certain axes. `ends` is a list * which has the same length with axis. - * @param axis The dimensions to slice. If `axis` is empty, it is set to [0, + * @param axes The dimensions to slice. If `axis` is empty, it is set to [0, * 1, ..., d-1], where d is the number of dimensions of the input tensor. * @param steps The step to slice at certain axes. `step` is a list which * has the same length with axis. */ SliceObj(GraphObj *graph, Tensor input, Tensor output, const vector &starts, const vector &ends, - const optional> &axis, + const optional> &axes, const optional> &steps); OP_CLONE(SliceObj); @@ -41,4 +41,4 @@ class SliceObj : public OperatorObj { vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/operators/slice.cc b/src/operators/slice.cc index 5987531f..e5a5ec8d 100644 --- a/src/operators/slice.cc +++ b/src/operators/slice.cc @@ -3,34 +3,33 @@ namespace infini { SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output, const vector &starts, const vector &ends, - const optional> &axis, + const optional> &axes, const optional> &steps) : OperatorObj(OpType::Slice, {input}, {output}) { - if (steps != std::nullopt) + if (steps) IT_TODO_HALT(); IT_ASSERT(starts.size() == ends.size()); - if (axis == std::nullopt) { + if (!axes) { this->starts = starts; this->ends = ends; } else { - int nAxis = (*axis).size(); - IT_ASSERT((int)starts.size() == nAxis); + auto nAxis = (*axes).size(); + IT_ASSERT(starts.size() == nAxis); - int nDims = input->getDims().size(); - vector tmpS(nDims, 0), tmpE; - for (int i = 0; i < nDims; ++i) { - tmpE.emplace_back(input->getDims()[i] - 1); - } + auto dims = input->getDims(); + this->starts = vector(dims.size(), 0); + this->ends.resize(dims.size()); + std::transform(dims.begin(), dims.end(), this->ends.begin(), + [](auto x) { return x - 1; }); - for (int i = 0; i < nAxis; ++i) { - if ((*axis)[i] < 0) + for (size_t j = 0; j < nAxis; ++j) { + auto i = (*axes)[j]; + if (i < 0) IT_TODO_HALT(); - tmpS[(*axis)[i]] = starts[i]; - tmpE[(*axis)[i]] = ends[i]; + this->starts[i] = starts[j]; + this->ends[i] = ends[j]; } - this->starts = tmpS; - this->ends = tmpE; } IT_ASSERT(checkValid(graph)); }