forked from jiuyuan/InfiniTensor
Add documentation for operators.
This commit is contained in:
parent
cca4d2a491
commit
26be533faa
|
@ -2,7 +2,11 @@
|
|||
#include "core/operator.h"
|
||||
#include <assert.h>
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief General to band matrix multiplication, which is used for Longformer
|
||||
* model. See https://arxiv.org/pdf/2004.05150.pdf for detail.
|
||||
*
|
||||
*/
|
||||
class G2BMMObj : public OperatorObj {
|
||||
private:
|
||||
// to be implemented
|
||||
|
@ -13,15 +17,17 @@ class G2BMMObj : public OperatorObj {
|
|||
|
||||
public:
|
||||
/**
|
||||
* @brief This comments show how operators is defined in InfiniTensor. The
|
||||
* constructor can create output tensors for the operator or not, which
|
||||
* depends on `graph`.
|
||||
* @brief Construct a new G2BMM object.
|
||||
*
|
||||
* @param graph If graph is not empty, create outputs in the constructor.
|
||||
* Otherwise, check the provided shape with the results of `inferShape` in
|
||||
* `checkValid`.
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param A The input tensor.
|
||||
* @param B The input tensor.
|
||||
* @param C C is the output of G2BMM. If outputs are going to be created in
|
||||
* the constructor, C should be an empty Ref.
|
||||
* @param width The width of the attention window.
|
||||
* @param dilation The dilation of the attention window.
|
||||
* @param bias The bias tensor.
|
||||
* @param act The activation.
|
||||
*/
|
||||
G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int width,
|
||||
const int dilation, Tensor bias = nullptr,
|
||||
|
|
|
@ -2,7 +2,12 @@
|
|||
#include "core/operator.h"
|
||||
#include <assert.h>
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief General band matrix multiplication. See
|
||||
* https://cscproxy.mpi-magdeburg.mpg.de/mpcsc/benner/pub/brdeq-cle2014.pdf for
|
||||
* detail.
|
||||
*
|
||||
*/
|
||||
class GBMMObj : public OperatorObj {
|
||||
private:
|
||||
int dilation;
|
||||
|
@ -12,15 +17,16 @@ class GBMMObj : public OperatorObj {
|
|||
|
||||
public:
|
||||
/**
|
||||
* @brief This comments show how operators is defined in InfiniTensor. The
|
||||
* constructor can create output tensors for the operator or not, which
|
||||
* depends on `graph`.
|
||||
* @brief Construct a new GBMM object.
|
||||
*
|
||||
* @param graph If graph is not empty, create outputs in the constructor.
|
||||
* Otherwise, check the provided shape with the results of `inferShape` in
|
||||
* `checkValid`.
|
||||
* @param C C is the output of GBMM. If outputs are going to be created in
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param A The input tensor.
|
||||
* @param B The input tensor.
|
||||
* @param C C is the output of G2BMM. If outputs are going to be created in
|
||||
* the constructor, C should be an empty Ref.
|
||||
* @param dilation The dilation of the attention window.
|
||||
* @param bias The bias tensor.
|
||||
* @param act The activation.
|
||||
*/
|
||||
GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int dilation,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
|
|
|
@ -2,11 +2,34 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief See https://arxiv.org/abs/1502.03167 for the detail of batch
|
||||
* normalization.
|
||||
*
|
||||
*/
|
||||
class BatchNormObj : public OperatorObj {
|
||||
float momentum, eps;
|
||||
bool training;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new BatchNorm object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor of BatchNorm. For image data, the input
|
||||
* shape is usually [N, C, H, W].
|
||||
* @param output The output tensor of BatchNorm, which should have the same
|
||||
* shape as the input tensor.
|
||||
* @param mean The mean tensor, which has a shape of [C].
|
||||
* @param var The var tensor, which has a shape of [C].
|
||||
* @param scale The scale tensor, which has a shape of [C].
|
||||
* @param bias The bias tensor, which has a shape of [C].
|
||||
* @param momentum Factor used in computing the running mean and variance.
|
||||
* Default is 0.9.
|
||||
* @param eps The epsilon value to use to avoid division by zero. Default is
|
||||
* 1e-5.
|
||||
* @param training Set to true when used for training.
|
||||
*/
|
||||
BatchNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean,
|
||||
Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
|
||||
float eps = 1e-5, bool training = false);
|
||||
|
|
|
@ -2,10 +2,23 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Concatenate several tensors into one. All the input tensors should
|
||||
* have the same shape except for the concatenated dimension.
|
||||
*
|
||||
*/
|
||||
class ConcatObj : public OperatorObj {
|
||||
int dim;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Concat object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param inputs The input tensors to be concatenated.
|
||||
* @param output Concatenated tensor.
|
||||
* @param dim The dimension to concatenate on.
|
||||
*/
|
||||
ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim);
|
||||
OP_CLONE(ConcatObj);
|
||||
|
||||
|
|
|
@ -2,7 +2,30 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief Convolution. Currently this operator only supports 2-D convolution.
|
||||
* This is the base class for convolution and transposed convolution.
|
||||
* The input tensor has four dimensions, called N (batch), C (channel), H
|
||||
* (height), and W (width) respectively; The weight tensor has four dimensions,
|
||||
* called F (number of filters), C (channel), R (height of weight), and S (width
|
||||
* of weight) respectively; The output tensor has four dimensions, called N, F,
|
||||
* H, and W respectively. By default, we take NCHW layout for the input and
|
||||
* output tensors, and FCRS layout for the weight tensor.
|
||||
* Convolutions have three attributes, called padding, stride, and dilation.
|
||||
* Padding is assigned by padding mode or padding size. Padding mode must be
|
||||
* Other, Same, or Valid (see the definition of enum class PaddingMode). Same
|
||||
* means the output has the same shape as the input. Valid means padding size is
|
||||
* 0. Other means padding size is assigned by value ph and pw, denoting the
|
||||
* padding size along height dimension and weight dimension, respectively.
|
||||
* Stride is assigned by sh and sw, denoting the stride along height dimension
|
||||
* and weight dimension, respectively.
|
||||
* Dilation is assigned by dh and dw, denoting the dilation along height
|
||||
* dimension and weight dimension, respectively.
|
||||
* See
|
||||
* https://towardsdatascience.com/types-of-convolutions-in-deep-learning-717013397f4d
|
||||
* for a detailed explanation of convolution.
|
||||
*
|
||||
*/
|
||||
class ConvBaseObj : public OperatorObj {
|
||||
public:
|
||||
// When PaddingMode is Other, ConvObj will use padding size (ph, pw)
|
||||
|
@ -18,7 +41,7 @@ class ConvBaseObj : public OperatorObj {
|
|||
int sh, sw;
|
||||
int dh, dw;
|
||||
PaddingMode padding;
|
||||
// auxiliary attributes. Descripitions stand on a forward perspective,
|
||||
// Auxiliary attributes. Descripitions stand on a forward perspective,
|
||||
// i.e., convTransposed2d is not regarded as the backward of conv2d.
|
||||
int n; // batch size
|
||||
int c; // input/output channel for conv2d/convTransposed2d
|
||||
|
@ -27,10 +50,43 @@ class ConvBaseObj : public OperatorObj {
|
|||
int r, s; // weight shape
|
||||
|
||||
public:
|
||||
// Constructors for explicitly setting padding size
|
||||
/**
|
||||
* @brief Construct a new ConvBase object by explicitly setting padding
|
||||
* size.
|
||||
*
|
||||
* @param opType Indicate if this is a convolution or transposed
|
||||
* convolution.
|
||||
* @param inputs The input, weight and bias tensors. Bias is optional.
|
||||
* FIXME: Split inputs into three parameters, input, weight, and bias.
|
||||
* @param output The output tensor.
|
||||
* @param ph Padding along height dimension.
|
||||
* @param pw Padding along weight dimension.
|
||||
* @param sh Stride along height dimension.
|
||||
* @param sw Stride along weight dimension.
|
||||
* @param dh Dilation along height dimension.
|
||||
* @param dw Dilation along weight dimension.
|
||||
* @param inputInConvFWD To be removed.
|
||||
* @param weightInConvFWD To be removed.
|
||||
*/
|
||||
ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw, const Tensor &inputInConvFWD,
|
||||
const Tensor &weightInConvFWD);
|
||||
/**
|
||||
* @brief Construct a new ConvBase object by setting padding mode.
|
||||
*
|
||||
* @param opType Indicate if this is a convolution or transposed
|
||||
* convolution.
|
||||
* @param inputs The input, weight and bias tensors. Bias is optional.
|
||||
* FIXME: Split inputs into three parameters, input, weight, and bias.
|
||||
* @param output The output tensor.
|
||||
* @param mode Padding mode, which is set to Other, Same, or Valid.
|
||||
* @param sh Stride along height dimension.
|
||||
* @param sw Stride along weight dimension.
|
||||
* @param dh Dilation along height dimension.
|
||||
* @param dw Dilation along weight dimension.
|
||||
* @param inputInConvFWD To be removed.
|
||||
* @param weightInConvFWD To be removed.
|
||||
*/
|
||||
ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
|
||||
PaddingMode mode, int sh, int sw, int dh, int dw,
|
||||
const Tensor &inputInConvFWD, const Tensor &weightInConvFWD);
|
||||
|
|
|
@ -2,8 +2,23 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Base class of **binary** element-wise operators.
|
||||
* Unary operators like activations are not the derived classes of
|
||||
* ElementWiseObj.
|
||||
*
|
||||
*/
|
||||
class ElementWiseObj : public OperatorObj {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new ElementWise object
|
||||
*
|
||||
* @param type Operator type.
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input0 The first input tensor.
|
||||
* @param input1 The second input tensor.
|
||||
* @param output The output tensor.
|
||||
*/
|
||||
ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
|
||||
Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
|
|
@ -2,10 +2,24 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Copy a tensor along a centain dimension for multiple times.
|
||||
*
|
||||
*/
|
||||
class ExtendObj : public OperatorObj {
|
||||
int dim, num; // copy num times at the dim.
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Extend object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The extened tensor.
|
||||
* @param dim The dimension to extend on.
|
||||
* @param num The number of times to copy when extending. The dimension size
|
||||
* of `dim` becomes `num+1` times after extending.
|
||||
*/
|
||||
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||
int num = 1);
|
||||
OP_CLONE(ExtendObj);
|
||||
|
|
|
@ -3,10 +3,24 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Gather and concatenate given positions on a certain dimension of the
|
||||
* input tensor using an index tensor.
|
||||
*
|
||||
*/
|
||||
class GatherObj : public OperatorObj {
|
||||
int axis;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Gather object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param index The index tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axis The axis to gather on.
|
||||
*/
|
||||
GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
|
||||
int axis);
|
||||
OP_CLONE(GatherObj);
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief Matrix multiplication.
|
||||
*
|
||||
*/
|
||||
class MatmulObj : public OperatorObj {
|
||||
private:
|
||||
// InfiniTensor assumes a row-major tensor layout. `transA`=false means
|
||||
|
@ -16,15 +19,19 @@ class MatmulObj : public OperatorObj {
|
|||
|
||||
public:
|
||||
/**
|
||||
* @brief This comments show how operators is defined in InfiniTensor. The
|
||||
* constructor can create output tensors for the operator or not, which
|
||||
* depends on `graph`.
|
||||
* @brief Construct a new Matmul object. This comments show how operators is
|
||||
* defined in InfiniTensor. The constructor can create output tensors for
|
||||
* the operator or not, which depends on `graph`.
|
||||
*
|
||||
* @param graph If graph is not empty, create outputs in the constructor.
|
||||
* Otherwise, check the provided shape with the results of `inferShape` in
|
||||
* `checkValid`.
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param A The input tensor.
|
||||
* @param B The input tensor.
|
||||
* @param C C is the output of Matmul. If outputs are going to be created in
|
||||
* the constructor, C should be an empty Ref.
|
||||
* @param transA If matrix A should be transposed when computing.
|
||||
* @param transB If matrix B should be transposed when computing.
|
||||
* @param bias The bias tensor.
|
||||
* @param act The activation function.
|
||||
*/
|
||||
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
|
||||
bool transA = false, bool transB = false, Tensor bias = nullptr,
|
||||
|
|
|
@ -2,12 +2,27 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Add data at the out side of a tensor.
|
||||
*
|
||||
*/
|
||||
class PadObj : public OperatorObj {
|
||||
// the number of start and end pad values for all dims.
|
||||
vector<int> pads;
|
||||
|
||||
public:
|
||||
// pad for appointed axises,if axis is empty,then pad for all axises.
|
||||
/**
|
||||
* @brief Construct a new Pad object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The padded tensor.
|
||||
* @param pads Add padding elements at the begining and end of each axis.
|
||||
* Suppose that padding axes are [x1, x2, ...], then pads's format is
|
||||
* [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
|
||||
* @param axis Pad for appointed axes. If axis is empty, pad for all axes.
|
||||
*/
|
||||
PadObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const vector<int> &pads, const optional<const vector<int>> &axis);
|
||||
OP_CLONE(PadObj);
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
/**
|
||||
* @brief The base class for AvgPool and MaxPool.
|
||||
*
|
||||
*/
|
||||
class PoolingObj : public OperatorObj {
|
||||
private:
|
||||
int kh, kw;
|
||||
|
@ -12,6 +15,24 @@ class PoolingObj : public OperatorObj {
|
|||
int n, c, h, w;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Pooling object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param optype Operator type of this pooling operator.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param kh Kernel height.
|
||||
* @param kw Kernel width.
|
||||
* FIXME: Dilated pooling is not supported for many frameworks?
|
||||
* @param dh Dilation at the height dimension.
|
||||
* @param dw Dilation at the width dimension.
|
||||
* FIXME: Auto padding using padding mode.
|
||||
* @param ph Padding at the height dimension.
|
||||
* @param pw Padding at the width dimension.
|
||||
* @param sh Stride at the height dimension.
|
||||
* @param sw Stride at the width dimension.
|
||||
*/
|
||||
PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
|
||||
int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw);
|
||||
OP_CLONE(PoolingObj);
|
||||
|
|
|
@ -2,11 +2,24 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Compute the mean of input tensor's elements along certain axes.
|
||||
*
|
||||
*/
|
||||
class ReduceMeanObj : public OperatorObj {
|
||||
set<int> axis; // axis to reduce
|
||||
bool keepDims;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new ReduceMean object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param axis Axes to reduce.
|
||||
* @param keepDims Keep the reduced dimensions or not.
|
||||
*/
|
||||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<const vector<int>> &axis,
|
||||
bool keepDims = true);
|
||||
|
|
|
@ -3,10 +3,22 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Change the shape of the input tensor.
|
||||
*
|
||||
*/
|
||||
class ReshapeObj : public OperatorObj {
|
||||
Shape dims;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Reshape object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param dims The shape of the output tensor.
|
||||
*/
|
||||
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, const Shape &dims);
|
||||
OP_CLONE(ReshapeObj);
|
||||
|
||||
|
@ -21,9 +33,22 @@ class ReshapeObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Reshape the input tensor into a one-dimensional tensor.
|
||||
* FIXME: Move to an independent file.
|
||||
* FIXME: Different parameter list with ONNX and Pytorch.
|
||||
*
|
||||
*/
|
||||
class FlattenObj : public OperatorObj {
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Flatten object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output one-dimensional tensor.
|
||||
*/
|
||||
FlattenObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(FlattenObj);
|
||||
|
||||
|
@ -38,9 +63,21 @@ class FlattenObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Copy the input tensor.
|
||||
* FIXME: Move to an independent file.
|
||||
*
|
||||
*/
|
||||
class IdentityObj : public OperatorObj {
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Identity object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor, which is the same as the input tensor.
|
||||
*/
|
||||
IdentityObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(IdentityObj);
|
||||
|
||||
|
|
|
@ -3,6 +3,11 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Resize the input tensor. See
|
||||
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize for detail.
|
||||
*
|
||||
*/
|
||||
class ResizeObj : public OperatorObj {
|
||||
public:
|
||||
enum class ECoordinateTransMode {
|
||||
|
|
|
@ -2,10 +2,29 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Produce a slice of the input tensor along given dimensions.
|
||||
*
|
||||
*/
|
||||
class SliceObj : public OperatorObj {
|
||||
vector<int> starts, ends; // the start no. and end no. for all dims.
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Slice object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param starts The start position to slice at certain axes. `starts` is a
|
||||
* list which has the same length with axis.
|
||||
* @param ends The end position to slice at certain axes. `ends` is a list
|
||||
* which has the same length with axis.
|
||||
* @param axis The dimensions to slice. If `axis` is empty, it is set to [0,
|
||||
* 1, ..., d-1], where d is the number of dimensions of the input tensor.
|
||||
* @param steps The step to slice at certain axes. `step` is a list which
|
||||
* has the same length with axis.
|
||||
*/
|
||||
SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const vector<int> &starts, const vector<int> &ends,
|
||||
const optional<vector<int>> &axis,
|
||||
|
|
|
@ -2,12 +2,37 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Split a tensor into multiple ones.
|
||||
*
|
||||
*/
|
||||
class SplitObj : public OperatorObj {
|
||||
int dim, num; // split dim;Average split num or outputs size
|
||||
vector<int> ratio; // output dim ratio
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Split object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param outputs The output tensors after splitting.
|
||||
* @param dim The dimension to split.
|
||||
* @param num The number of output tensors. The input tensor is split into
|
||||
* `num` evenly chunk along dimension `dim. The last chunk will be smaller
|
||||
* if the input tensor cannot be evenly split.
|
||||
*/
|
||||
SplitObj(GraphObj *graph, Tensor input, std::optional<TensorVec> outputs,
|
||||
int dim, int num);
|
||||
/**
|
||||
* @brief Construct a new Split object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param outputs The output tensors after splitting.
|
||||
* @param dim The dimension to split.
|
||||
* @param ratio The size of dimension `dim` for the output tensors after
|
||||
* splitting.
|
||||
*/
|
||||
SplitObj(GraphObj *graph, Tensor input, std::optional<TensorVec> outputs,
|
||||
int dim, const vector<int> &ratio);
|
||||
OP_CLONE(SplitObj);
|
||||
|
|
|
@ -2,8 +2,20 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief The base class for unary operators.
|
||||
*
|
||||
*/
|
||||
class UnaryObj : public OperatorObj {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Unary object.
|
||||
*
|
||||
* @param type Operator type.
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
*/
|
||||
UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
|
|
Loading…
Reference in New Issue