Add: clone for operators

This commit is contained in:
Liyan Zheng 2022-11-15 21:15:10 +08:00
parent f133f00478
commit e991b3261b
17 changed files with 18 additions and 0 deletions

View File

@ -26,6 +26,7 @@ class G2BMMObj : public OperatorObj {
G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int width,
const int dilation, Tensor bias = nullptr,
ActType act = ActType::None);
OP_CLONE(G2BMMObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -24,6 +24,7 @@ class GBMMObj : public OperatorObj {
*/
GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int dilation,
Tensor bias = nullptr, ActType act = ActType::None);
OP_CLONE(GBMMObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -10,6 +10,7 @@ class BatchNormObj : public OperatorObj {
BatchNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
float eps = 1e-5, bool training = false);
OP_CLONE(BatchNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -77,6 +77,7 @@ class ConvObj : public ConvBaseObj {
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
int dh = 1, int dw = 1, Tensor bias = nullptr,
ActType act = ActType::None);
OP_CLONE(ConvObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }
@ -104,6 +105,7 @@ class ConvTransposed2dObj : public ConvBaseObj {
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
int oph = 0, int opw = 0, int group = 1,
Tensor bias = nullptr, ActType act = ActType::None);
OP_CLONE(ConvTransposed2dObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }

View File

@ -23,6 +23,7 @@ class ElementWiseObj : public OperatorObj {
prefix##Obj(GraphObj *graph, Tensor input0, Tensor input1, \
Tensor output) \
: ElementWiseObj(type, graph, input0, input1, output) {} \
OP_CLONE(prefix##Obj); \
};
DEFINE_ELEMENT_WISE_OBJ(Add, OpType::Add)

View File

@ -8,6 +8,7 @@ class ExtendObj : public OperatorObj {
public:
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
int num = 1);
OP_CLONE(ExtendObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -9,6 +9,7 @@ class GatherObj : public OperatorObj {
public:
GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
int axis);
OP_CLONE(GatherObj);
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }

View File

@ -29,6 +29,7 @@ class MatmulObj : public OperatorObj {
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
bool transA = false, bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
OP_CLONE(MatmulObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -17,6 +17,7 @@ class MemBoundObj : public OperatorObj {
const TensorVec &output,
const std::vector<nnet::Tensor> &nnetInputs, nnet::Expr expr,
double exec_time, std::string hint = {});
OP_CLONE(MemBoundObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -10,6 +10,7 @@ class PadObj : public OperatorObj {
// pad for appointed axises,if axis is empty,then pad for all axises.
PadObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &pads, const optional<const vector<int>> &axis);
OP_CLONE(PadObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -14,6 +14,7 @@ class PoolingObj : public OperatorObj {
public:
PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw);
OP_CLONE(PoolingObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -10,6 +10,7 @@ class ReduceMeanObj : public OperatorObj {
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &axis,
bool keepDims = true);
OP_CLONE(ReduceMeanObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -8,6 +8,7 @@ class ReshapeObj : public OperatorObj {
public:
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, const Shape &dims);
OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -51,6 +51,7 @@ class ResizeObj : public OperatorObj {
GraphObj *graph, Tensor input, Tensor output,
const std::optional<vector<int>> &axes, Tensor scales, ECoeffMode mode,
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
OP_CLONE(ResizeObj);
vector<DataType> inferDataType(const TensorVec &inputs) const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -10,6 +10,7 @@ class SliceObj : public OperatorObj {
const vector<int> &starts, const vector<int> &ends,
const optional<vector<int>> &axis,
const optional<vector<int>> &steps);
OP_CLONE(SliceObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -10,6 +10,7 @@ class SplitObj : public OperatorObj {
int dim, int num);
SplitObj(GraphObj *graph, Tensor input, std::optional<TensorVec> outputs,
int dim, const vector<int> &ratio);
OP_CLONE(SplitObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -21,6 +21,7 @@ class UnaryObj : public OperatorObj {
public: \
prefix##Obj(GraphObj *graph, Tensor input, Tensor output) \
: UnaryObj(type, graph, input, output) {} \
OP_CLONE(prefix##Obj); \
};
DEFINE_UNARY_OBJ(Relu, OpType::Relu)