forked from jiuyuan/InfiniTensor
opt: 优化 SliceObj 构造器实现
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
341cf1f943
commit
f9d0076a86
|
@ -20,14 +20,14 @@ class SliceObj : public OperatorObj {
|
|||
* list which has the same length with axis.
|
||||
* @param ends The end position to slice at certain axes. `ends` is a list
|
||||
* which has the same length with axis.
|
||||
* @param axis The dimensions to slice. If `axis` is empty, it is set to [0,
|
||||
* @param axes The dimensions to slice. If `axis` is empty, it is set to [0,
|
||||
* 1, ..., d-1], where d is the number of dimensions of the input tensor.
|
||||
* @param steps The step to slice at certain axes. `step` is a list which
|
||||
* has the same length with axis.
|
||||
*/
|
||||
SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const vector<int> &starts, const vector<int> &ends,
|
||||
const optional<vector<int>> &axis,
|
||||
const optional<vector<int>> &axes,
|
||||
const optional<vector<int>> &steps);
|
||||
OP_CLONE(SliceObj);
|
||||
|
||||
|
@ -41,4 +41,4 @@ class SliceObj : public OperatorObj {
|
|||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -3,34 +3,33 @@
|
|||
namespace infini {
|
||||
SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const vector<int> &starts, const vector<int> &ends,
|
||||
const optional<vector<int>> &axis,
|
||||
const optional<vector<int>> &axes,
|
||||
const optional<vector<int>> &steps)
|
||||
: OperatorObj(OpType::Slice, {input}, {output}) {
|
||||
if (steps != std::nullopt)
|
||||
if (steps)
|
||||
IT_TODO_HALT();
|
||||
IT_ASSERT(starts.size() == ends.size());
|
||||
|
||||
if (axis == std::nullopt) {
|
||||
if (!axes) {
|
||||
this->starts = starts;
|
||||
this->ends = ends;
|
||||
} else {
|
||||
int nAxis = (*axis).size();
|
||||
IT_ASSERT((int)starts.size() == nAxis);
|
||||
auto nAxis = (*axes).size();
|
||||
IT_ASSERT(starts.size() == nAxis);
|
||||
|
||||
int nDims = input->getDims().size();
|
||||
vector<int> tmpS(nDims, 0), tmpE;
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
tmpE.emplace_back(input->getDims()[i] - 1);
|
||||
}
|
||||
auto dims = input->getDims();
|
||||
this->starts = vector<int>(dims.size(), 0);
|
||||
this->ends.resize(dims.size());
|
||||
std::transform(dims.begin(), dims.end(), this->ends.begin(),
|
||||
[](auto x) { return x - 1; });
|
||||
|
||||
for (int i = 0; i < nAxis; ++i) {
|
||||
if ((*axis)[i] < 0)
|
||||
for (size_t j = 0; j < nAxis; ++j) {
|
||||
auto i = (*axes)[j];
|
||||
if (i < 0)
|
||||
IT_TODO_HALT();
|
||||
tmpS[(*axis)[i]] = starts[i];
|
||||
tmpE[(*axis)[i]] = ends[i];
|
||||
this->starts[i] = starts[j];
|
||||
this->ends[i] = ends[j];
|
||||
}
|
||||
this->starts = tmpS;
|
||||
this->ends = tmpE;
|
||||
}
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue