opt: 优化 SliceObj 构造器实现

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 16:44:08 +08:00
parent 341cf1f943
commit f9d0076a86
2 changed files with 18 additions and 19 deletions

View File

@ -20,14 +20,14 @@ class SliceObj : public OperatorObj {
* list which has the same length with axis. * list which has the same length with axis.
* @param ends The end position to slice at certain axes. `ends` is a list * @param ends The end position to slice at certain axes. `ends` is a list
* which has the same length with axis. * which has the same length with axis.
* @param axis The dimensions to slice. If `axis` is empty, it is set to [0, * @param axes The dimensions to slice. If `axis` is empty, it is set to [0,
* 1, ..., d-1], where d is the number of dimensions of the input tensor. * 1, ..., d-1], where d is the number of dimensions of the input tensor.
* @param steps The step to slice at certain axes. `step` is a list which * @param steps The step to slice at certain axes. `step` is a list which
* has the same length with axis. * has the same length with axis.
*/ */
SliceObj(GraphObj *graph, Tensor input, Tensor output, SliceObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &starts, const vector<int> &ends, const vector<int> &starts, const vector<int> &ends,
const optional<vector<int>> &axis, const optional<vector<int>> &axes,
const optional<vector<int>> &steps); const optional<vector<int>> &steps);
OP_CLONE(SliceObj); OP_CLONE(SliceObj);

View File

@ -3,34 +3,33 @@
namespace infini { namespace infini {
SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output, SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &starts, const vector<int> &ends, const vector<int> &starts, const vector<int> &ends,
const optional<vector<int>> &axis, const optional<vector<int>> &axes,
const optional<vector<int>> &steps) const optional<vector<int>> &steps)
: OperatorObj(OpType::Slice, {input}, {output}) { : OperatorObj(OpType::Slice, {input}, {output}) {
if (steps != std::nullopt) if (steps)
IT_TODO_HALT(); IT_TODO_HALT();
IT_ASSERT(starts.size() == ends.size()); IT_ASSERT(starts.size() == ends.size());
if (axis == std::nullopt) { if (!axes) {
this->starts = starts; this->starts = starts;
this->ends = ends; this->ends = ends;
} else { } else {
int nAxis = (*axis).size(); auto nAxis = (*axes).size();
IT_ASSERT((int)starts.size() == nAxis); IT_ASSERT(starts.size() == nAxis);
int nDims = input->getDims().size(); auto dims = input->getDims();
vector<int> tmpS(nDims, 0), tmpE; this->starts = vector<int>(dims.size(), 0);
for (int i = 0; i < nDims; ++i) { this->ends.resize(dims.size());
tmpE.emplace_back(input->getDims()[i] - 1); std::transform(dims.begin(), dims.end(), this->ends.begin(),
} [](auto x) { return x - 1; });
for (int i = 0; i < nAxis; ++i) { for (size_t j = 0; j < nAxis; ++j) {
if ((*axis)[i] < 0) auto i = (*axes)[j];
if (i < 0)
IT_TODO_HALT(); IT_TODO_HALT();
tmpS[(*axis)[i]] = starts[i]; this->starts[i] = starts[j];
tmpE[(*axis)[i]] = ends[i]; this->ends[i] = ends[j];
} }
this->starts = tmpS;
this->ends = tmpE;
} }
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }