2022-09-29 10:29:24 +08:00
|
|
|
#pragma once
|
|
|
|
#include "core/operator.h"
|
|
|
|
|
|
|
|
namespace infini {
|
2023-02-13 22:48:20 +08:00
|
|
|
/**
 * @brief Produce a slice of the input tensor along given dimensions.
 *
 */
|
2022-09-29 10:29:24 +08:00
|
|
|
class SliceObj : public OperatorObj {
|
|
|
|
vector<int> starts, ends; // the start no. and end no. for all dims.
|
|
|
|
|
|
|
|
public:
|
2023-02-13 22:48:20 +08:00
|
|
|
/**
|
|
|
|
* @brief Construct a new Slice object.
|
|
|
|
*
|
|
|
|
* @param graph The computation graph that this operator belongs to.
|
|
|
|
* @param input The input tensor.
|
|
|
|
* @param output The output tensor.
|
|
|
|
* @param starts The start position to slice at certain axes. `starts` is a
|
|
|
|
* list which has the same length with axis.
|
|
|
|
* @param ends The end position to slice at certain axes. `ends` is a list
|
|
|
|
* which has the same length with axis.
|
2023-02-14 16:44:08 +08:00
|
|
|
* @param axes The dimensions to slice. If `axis` is empty, it is set to [0,
|
2023-02-13 22:48:20 +08:00
|
|
|
* 1, ..., d-1], where d is the number of dimensions of the input tensor.
|
|
|
|
* @param steps The step to slice at certain axes. `step` is a list which
|
|
|
|
* has the same length with axis.
|
|
|
|
*/
|
2022-09-29 10:29:24 +08:00
|
|
|
SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
|
|
|
const vector<int> &starts, const vector<int> &ends,
|
2023-02-14 16:44:08 +08:00
|
|
|
const optional<vector<int>> &axes,
|
2022-09-29 10:29:24 +08:00
|
|
|
const optional<vector<int>> &steps);
|
2023-02-12 18:27:52 +08:00
|
|
|
OP_CLONE(SliceObj);
|
2022-09-29 10:29:24 +08:00
|
|
|
|
|
|
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
|
|
|
std::string toString() const override;
|
|
|
|
int numInputs() const override { return 1; }
|
|
|
|
int numOutputs() const override { return 1; }
|
|
|
|
Shape getStart() const { return starts; }
|
|
|
|
|
|
|
|
private:
|
|
|
|
vector<int> getWorkloadVector() const override;
|
|
|
|
vector<int> getOpAttrVector() const override;
|
|
|
|
};
|
2023-02-14 16:44:08 +08:00
|
|
|
} // namespace infini
|