2023-04-18 15:10:33 +08:00
|
|
|
#pragma once
|
|
|
|
#include "core/operator.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
class TransposeObj : public OperatorObj {
|
|
|
|
public:
|
|
|
|
TransposeObj(GraphObj *graph, Tensor input, Tensor output,
|
|
|
|
vector<int> permute);
|
|
|
|
OP_CLONE(TransposeObj);
|
2023-11-23 13:11:50 +08:00
|
|
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
2023-04-18 15:10:33 +08:00
|
|
|
|
|
|
|
std::string toString() const override;
|
|
|
|
int numInputs() const override { return 1; }
|
|
|
|
int numOutputs() const override { return 1; }
|
|
|
|
std::vector<int> getPermute() const { return transposePermute; }
|
|
|
|
|
|
|
|
private:
|
2023-08-16 21:49:43 +08:00
|
|
|
vector<int> transposePermute;
|
2023-04-18 15:10:33 +08:00
|
|
|
vector<int> getWorkloadVector() const override;
|
|
|
|
vector<int> getOpAttrVector() const override;
|
|
|
|
};
|
2023-11-10 17:58:26 +08:00
|
|
|
|
|
|
|
class DepthToSpaceObj : public OperatorObj {
|
|
|
|
public:
|
|
|
|
DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize,
|
|
|
|
std::string mode);
|
|
|
|
OP_CLONE(DepthToSpaceObj);
|
2023-11-23 13:11:50 +08:00
|
|
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
2023-11-10 17:58:26 +08:00
|
|
|
|
|
|
|
std::string toString() const override;
|
|
|
|
int numInputs() const override { return 1; }
|
|
|
|
int numOutputs() const override { return 1; }
|
|
|
|
int getBlockSize() const { return blockSize; }
|
|
|
|
int getMode() const { return D2SMode; }
|
|
|
|
auto getModeString() const { return D2SModeString; }
|
|
|
|
auto getReshapeDim() const { return reshapeDim; }
|
|
|
|
auto getTransposeDim() const { return transposeDim; }
|
|
|
|
auto getOutDim() const { return outDim; }
|
|
|
|
|
|
|
|
private:
|
|
|
|
int blockSize;
|
|
|
|
int D2SMode;
|
|
|
|
std::string D2SModeString;
|
|
|
|
mutable std::vector<int> reshapeDim = {1, 1, 1, 1, 1, 1};
|
|
|
|
mutable std::vector<int> transposeDim = {1, 1, 1, 1, 1, 1};
|
|
|
|
mutable std::vector<int> outDim = {1, 1, 1, 1};
|
|
|
|
vector<int> getWorkloadVector() const override;
|
|
|
|
vector<int> getOpAttrVector() const override;
|
|
|
|
};
|
|
|
|
|
2023-10-12 10:14:28 +08:00
|
|
|
} // namespace infini
|