forked from jiuyuan/InfiniTensor
opt: 优化 SliceObj 构造器实现
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
341cf1f943
commit
f9d0076a86
|
@ -20,14 +20,14 @@ class SliceObj : public OperatorObj {
|
||||||
* list which has the same length with axis.
|
* list which has the same length with axis.
|
||||||
* @param ends The end position to slice at certain axes. `ends` is a list
|
* @param ends The end position to slice at certain axes. `ends` is a list
|
||||||
* which has the same length with axis.
|
* which has the same length with axis.
|
||||||
* @param axis The dimensions to slice. If `axis` is empty, it is set to [0,
|
* @param axes The dimensions to slice. If `axis` is empty, it is set to [0,
|
||||||
* 1, ..., d-1], where d is the number of dimensions of the input tensor.
|
* 1, ..., d-1], where d is the number of dimensions of the input tensor.
|
||||||
* @param steps The step to slice at certain axes. `step` is a list which
|
* @param steps The step to slice at certain axes. `step` is a list which
|
||||||
* has the same length with axis.
|
* has the same length with axis.
|
||||||
*/
|
*/
|
||||||
SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const vector<int> &starts, const vector<int> &ends,
|
const vector<int> &starts, const vector<int> &ends,
|
||||||
const optional<vector<int>> &axis,
|
const optional<vector<int>> &axes,
|
||||||
const optional<vector<int>> &steps);
|
const optional<vector<int>> &steps);
|
||||||
OP_CLONE(SliceObj);
|
OP_CLONE(SliceObj);
|
||||||
|
|
||||||
|
|
|
@ -3,34 +3,33 @@
|
||||||
namespace infini {
|
namespace infini {
|
||||||
SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||||
const vector<int> &starts, const vector<int> &ends,
|
const vector<int> &starts, const vector<int> &ends,
|
||||||
const optional<vector<int>> &axis,
|
const optional<vector<int>> &axes,
|
||||||
const optional<vector<int>> &steps)
|
const optional<vector<int>> &steps)
|
||||||
: OperatorObj(OpType::Slice, {input}, {output}) {
|
: OperatorObj(OpType::Slice, {input}, {output}) {
|
||||||
if (steps != std::nullopt)
|
if (steps)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
IT_ASSERT(starts.size() == ends.size());
|
IT_ASSERT(starts.size() == ends.size());
|
||||||
|
|
||||||
if (axis == std::nullopt) {
|
if (!axes) {
|
||||||
this->starts = starts;
|
this->starts = starts;
|
||||||
this->ends = ends;
|
this->ends = ends;
|
||||||
} else {
|
} else {
|
||||||
int nAxis = (*axis).size();
|
auto nAxis = (*axes).size();
|
||||||
IT_ASSERT((int)starts.size() == nAxis);
|
IT_ASSERT(starts.size() == nAxis);
|
||||||
|
|
||||||
int nDims = input->getDims().size();
|
auto dims = input->getDims();
|
||||||
vector<int> tmpS(nDims, 0), tmpE;
|
this->starts = vector<int>(dims.size(), 0);
|
||||||
for (int i = 0; i < nDims; ++i) {
|
this->ends.resize(dims.size());
|
||||||
tmpE.emplace_back(input->getDims()[i] - 1);
|
std::transform(dims.begin(), dims.end(), this->ends.begin(),
|
||||||
}
|
[](auto x) { return x - 1; });
|
||||||
|
|
||||||
for (int i = 0; i < nAxis; ++i) {
|
for (size_t j = 0; j < nAxis; ++j) {
|
||||||
if ((*axis)[i] < 0)
|
auto i = (*axes)[j];
|
||||||
|
if (i < 0)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
tmpS[(*axis)[i]] = starts[i];
|
this->starts[i] = starts[j];
|
||||||
tmpE[(*axis)[i]] = ends[i];
|
this->ends[i] = ends[j];
|
||||||
}
|
}
|
||||||
this->starts = tmpS;
|
|
||||||
this->ends = tmpE;
|
|
||||||
}
|
}
|
||||||
IT_ASSERT(checkValid(graph));
|
IT_ASSERT(checkValid(graph));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue