// Forked from jiuyuan/InfiniTensor (C++ header, 32 lines, 1.1 KiB).
#pragma once
|
|
#include "core/operator.h"
|
|
|
|
namespace infini {
|
|
class UnaryObj : public OperatorObj {
|
|
public:
|
|
UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
|
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
|
|
|
std::string toString() const override;
|
|
int numInputs() const override { return 1; }
|
|
int numOutputs() const override { return 1; }
|
|
|
|
private:
|
|
vector<int> getWorkloadVector() const override;
|
|
vector<int> getOpAttrVector() const override;
|
|
};
|
|
|
|
/**
 * Declares `<prefix>Obj`, a concrete subclass of UnaryObj whose only member
 * is a constructor forwarding `type` (the operator kind) together with the
 * graph, input and output arguments to the UnaryObj constructor.
 *
 * The expansion already ends with `};`, so invocations are written without a
 * trailing semicolon. Every line but the last must end in a backslash with
 * nothing after it, otherwise the macro body is cut short.
 */
#define DEFINE_UNARY_OBJ(prefix, type)                                         \
    class prefix##Obj : public UnaryObj {                                      \
      public:                                                                  \
        prefix##Obj(GraphObj *graph, Tensor input, Tensor output)              \
            : UnaryObj(type, graph, input, output) {}                          \
    };
|
|
|
|
// Concrete unary operators. Each DEFINE_UNARY_OBJ line declares a class named
// <Prefix>Obj (ReluObj, SigmoidObj, ...) bound to the matching OpType value.
DEFINE_UNARY_OBJ(Relu, OpType::Relu)
DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
|
|
}; // namespace infini
|