// InfiniTensor/include/operators/reshape.h

#pragma once
#include "core/operator.h"
namespace infini {
2023-02-13 22:48:20 +08:00
/**
* @brief Change the shape of the input tensor.
*
*/
class ReshapeObj : public OperatorObj {
Shape dims;
Shape outputShape;
public:
2023-02-13 22:48:20 +08:00
/**
* @brief Construct a new Reshape object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor.
* @param dims The shape to infer the output shape.
* @param outputShape The real shape of output tensor.
2023-02-13 22:48:20 +08:00
*/
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
inline Shape getShape() const { return outputShape; }
inline Shape getDims() const { return dims; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
2023-02-13 22:48:20 +08:00
/**
* @brief Reshape the input tensor into a one-dimensional tensor.
* FIXME: Move to an independent file.
* FIXME: Different parameter list with ONNX and Pytorch.
*
*/
class FlattenObj : public OperatorObj {
int axis;
public:
2023-02-13 22:48:20 +08:00
/**
* @brief Construct a new Flatten object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output one-dimensional tensor.
*/
FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis);
OP_CLONE(FlattenObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
int getAxis() const { return axis; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
2023-02-13 22:48:20 +08:00
/**
* @brief Copy the input tensor.
* FIXME: Move to an independent file.
*
*/
class IdentityObj : public OperatorObj {
public:
2023-02-13 22:48:20 +08:00
/**
* @brief Construct a new Identity object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor, which is the same as the input tensor.
*/
IdentityObj(GraphObj *graph, Tensor input, Tensor output);
OP_CLONE(IdentityObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini