Chore: rename getOpAttrs to getOpPerfKey

This commit is contained in:
Liyan Zheng 2022-08-09 15:34:28 +08:00
parent 8b685ae4a6
commit 2054b0eda4
5 changed files with 8 additions and 9 deletions

View File

@@ -141,7 +141,7 @@ class OperatorNode : public Object {
OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs) OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
: type(opType), inputs(inputs), outputs(outputs) {} : type(opType), inputs(inputs), outputs(outputs) {}
virtual vector<Shape> computeShape() const = 0; virtual vector<Shape> computeShape() const = 0;
virtual OpPerfKey getOpAttrs() const = 0; virtual OpPerfKey getOpPerfKey() const = 0;
public: // check Op type public: // check Op type
bool isLinearOp() const; bool isLinearOp() const;

View File

@@ -35,7 +35,7 @@ class MatmulNode : public OperatorNode {
int getK() const { return k; } int getK() const { return k; }
HashType hashWithShape() const override; HashType hashWithShape() const override;
OpPerfKey getOpAttrs() const override; OpPerfKey getOpPerfKey() const override;
private: private:
// Q: whether to check the output? Since we can build an Op first and then // Q: whether to check the output? Since we can build an Op first and then

View File

@@ -19,7 +19,7 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
auto kernelAttrs = auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Int32}; KernelAttrs{device, op->getOpType(), DataType::Int32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey); std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
// If no record and disable tuning, run with the default argument // If no record and disable tuning, run with the default argument
@@ -66,7 +66,7 @@ double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
auto kernelAttrs = auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Int32}; KernelAttrs{device, op->getOpType(), DataType::Int32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()}; auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey); std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
PerfRecord record; PerfRecord record;

View File

@@ -4,13 +4,14 @@
namespace infini { namespace infini {
template <typename T> class NaiveMatmul : public Kernel { template <typename T> class NaiveMatmul : public Kernel {
void compute(const Operator &_op) const override { void compute(const Operator &_op, const PerfRecord &record) const override {
auto op = as<MatmulNode>(_op); auto op = as<MatmulNode>(_op);
T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get()); T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get()); T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get()); T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
IT_ASSERT(op->getTransA() == false && op->getTransB() == false); IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
IT_ASSERT(op->getAct() == ActType::None); IT_ASSERT(op->getAct() == ActType::None);
IT_ASSERT(op->getB() == 1);
const int M = op->getM(), N = op->getN(), K = op->getK(); const int M = op->getM(), N = op->getN(), K = op->getK();
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
@@ -22,9 +23,7 @@ template <typename T> class NaiveMatmul : public Kernel {
} }
} }
void compute(const Operator &op, const PerfRecord &record) const override { void compute(const Operator &op) const override { compute(op, {}); }
compute(op);
}
PerfRecord tune(const Operator &op) const override { PerfRecord tune(const Operator &op) const override {
return PerfRecord{.time = timeit([this, &op]() { compute(op); })}; return PerfRecord{.time = timeit([this, &op]() { compute(op); })};

View File

@@ -49,7 +49,7 @@ HashType MatmulNode::hashWithShape() const {
return b + m + n + k + transA + transB + enum_to_underlying(act); return b + m + n + k + transA + transB + enum_to_underlying(act);
} }
OpPerfKey MatmulNode::getOpAttrs() const { OpPerfKey MatmulNode::getOpPerfKey() const {
return OpPerfKey(hashWithShape(), type, return OpPerfKey(hashWithShape(), type,
{b, m, n, k, transA, transB, enum_to_underlying(act)}); {b, m, n, k, transA, transB, enum_to_underlying(act)});
} }