#pragma once

#include "core/op_type.h"
#include "core/tensor.h"

namespace infini {

using KernelAttrs = std::tuple<Device, OpType::underlying_t>;
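
// Illustrative example (an assumption, not part of the original header): a
// kernel is looked up by pairing a target device with an operator type, e.g.
//     KernelAttrs attrs{Device::CPU, op->getOpType().underlying()};
// where `Device::CPU` and the operator handle `op` are placeholder names used
// only for illustration.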

struct OpPerfKey {
    HashType hash;
    OpType::underlying_t opType;
    vector<int> attrs;

  public:
    // FIXME: default ctor should be deleted but json requires it. Solution:
    // https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types
    OpPerfKey() = default;
    OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {})
        : hash(hash), opType(opType.underlying()), attrs(attrs) {}
    bool operator==(const OpPerfKey &rhs) const {
        if (hash != rhs.hash)
            return false;
        if (opType != rhs.opType)
            return false;
        if (attrs != rhs.attrs)
            return false;
        return true;
    }

    // TODO: remove this function after we use unordered_map in PerfEngine
    bool operator<(const OpPerfKey &rhs) const {
        if (hash != rhs.hash)
            return hash < rhs.hash;
        if (opType != rhs.opType)
            return opType < rhs.opType;
        if (attrs.size() != rhs.attrs.size())
            return attrs.size() < rhs.attrs.size();
        for (size_t i = 0; i < attrs.size(); ++i)
            if (attrs[i] != rhs.attrs[i])
                return attrs[i] < rhs.attrs[i];
        return false;
    }
};
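
// Illustrative usage (an assumption, not part of the original header):
// OpPerfKey is meant to key a cache of tuning results, e.g.
//     std::map<OpPerfKey, MyPerfRecord> perfData;            // via operator<
//     std::unordered_map<OpPerfKey, MyPerfRecord> perfCache; // via std::hash
// where MyPerfRecord is a placeholder record type; the unordered variant
// relies on the std::hash<infini::OpPerfKey> specialization at the end of
// this file (see the TODO above about PerfEngine moving to unordered_map).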

class GraphObj;
class OperatorObj : public Object {
    friend class GraphObj;

  protected:
    OpType type;
    TensorVec inputs;
    TensorVec outputs;
    vector<WRef<OperatorObj>> predecessors;
    vector<WRef<OperatorObj>> successors;

  public:
    OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs);
    virtual optional<vector<Shape>> inferShape(const TensorVec &inputs) = 0;
    virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
    /**
     * @brief Constructs outputs (if required) and checks whether the
     * operator is valid.
     *
     * @param graph If graph is not nullptr, outputs should be created in this
     * function.
     */
    bool checkValid(GraphObj *graph);
    OpPerfKey getOpPerfKey() const;
    /**
     * @brief Hash operator attributes. Input and output shapes are not
     * considered.
     */
    HashType hash() const;

  public: // getter and setter
    const TensorVec &getInputs() const { return inputs; }
    const TensorVec &getOutputs() const { return outputs; }
    Tensor getInputs(size_t i) const { return inputs.at(i); }
    Tensor getOutput() const {
        IT_ASSERT(outputs.size() == 1, "Unimplemented");
        return outputs[0];
    }
    Tensor getOutput(size_t i) const {
        IT_ASSERT(i < outputs.size(), "Index exceeded");
        return outputs.at(i);
    }
    OpVec getPredecessors() const { return wrefs_to_refs(predecessors); }
    OpVec getSuccessors() const { return wrefs_to_refs(successors); }
    OpType getOpType() const { return type; }
    // HACK: set correct data type
    DataType getDType() const { return getInputs(0)->getDType(); }
    DataType getOutDType() const { return getOutput()->getDType(); }
    virtual int numInputs() const = 0;
    virtual int numOutputs() const = 0;
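
    // Concrete operators can generate the clone() override declared below
    // with the OP_CLONE macro defined at the bottom of this header.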

    /**
     * @brief Clone this operator and replace its inputs and outputs.
     *
     * @param newInputs
     * @param newOutputs
     * @return Operator
     */
    virtual Operator clone(const TensorVec &newInputs,
                           const TensorVec &newOutputs) const = 0;

  protected:
    optional<vector<Shape>> inferShape();
    vector<DataType> inferDataType() const;

  private:
    /**
     * @brief The returned vector includes operator attributes, such as
     * paddings in Conv and transpose in Matmul. However, the input and output
     * shapes are not taken into consideration.
     */
    virtual vector<int> getOpAttrVector() const { IT_TODO_HALT(); }
    /**
     * @brief Besides operator attributes, the returned vector includes input
     * and output shapes.
     */
    virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
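
    // Illustrative example (an assumption, not part of the original header):
    // for a convolution operator, getOpAttrVector() might return only the
    // attributes, e.g. {pad_h, pad_w, stride_h, stride_w}, while
    // getWorkloadVector() would additionally encode the input/output shapes,
    // so that two Convs with identical attributes but different tensor sizes
    // map to different workloads.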

    void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
    void addSuccessors(const Operator &op) { successors.emplace_back(op); }
    void removePredecessors(const Operator &op);
    void removeSuccessors(const Operator &op);
    void replaceInput(Tensor t1, Tensor t2);
};

#define OP_CLONE(OpObj)                                                        \
    virtual Operator clone(const TensorVec &newInputs,                         \
                           const TensorVec &newOutputs) const override {       \
        auto op = infini::make_ref<OpObj>(*this);                              \
        op->inputs = newInputs;                                                \
        op->outputs = newOutputs;                                              \
        op->predecessors.clear();                                              \
        op->successors.clear();                                                \
        IT_ASSERT(op->checkValid(nullptr));                                    \
        return op;                                                             \
    }
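
// Illustrative usage (an assumption, not part of the original header): a
// concrete operator places the macro in its class body to generate the
// clone() override, e.g.
//     class MyUnaryObj : public OperatorObj {
//         ...
//         OP_CLONE(MyUnaryObj);
//     };
// where MyUnaryObj is a placeholder class name.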

} // namespace infini

namespace std {
template <> struct hash<infini::OpPerfKey> {
    size_t operator()(const infini::OpPerfKey &key) const { return key.hash; }
};
} // namespace std