diff --git a/include/core/operator.h b/include/core/operator.h
index 9d04aed2..a24f240d 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -141,7 +141,7 @@ class OperatorNode : public Object {
     OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
         : type(opType), inputs(inputs), outputs(outputs) {}
     virtual vector<Shape> computeShape() const = 0;
-    virtual OpPerfKey getOpAttrs() const = 0;
+    virtual OpPerfKey getOpPerfKey() const = 0;
 
   public: // check Op type
     bool isLinearOp() const;
diff --git a/include/operators/matmul.h b/include/operators/matmul.h
index 405b3f76..b94dabe0 100644
--- a/include/operators/matmul.h
+++ b/include/operators/matmul.h
@@ -35,7 +35,7 @@ class MatmulNode : public OperatorNode {
     int getK() const { return k; }
 
     HashType hashWithShape() const override;
-    OpPerfKey getOpAttrs() const override;
+    OpPerfKey getOpPerfKey() const override;
 
   private:
     // Q: whether to check the output? Since we can build an Op first and then
diff --git a/src/core/run_engine.cc b/src/core/run_engine.cc
index 155738e1..ba6878bc 100644
--- a/src/core/run_engine.cc
+++ b/src/core/run_engine.cc
@@ -19,7 +19,7 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
         auto kernelAttrs =
             KernelAttrs{device, op->getOpType(), DataType::Int32};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
-        auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()};
+        auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         std::optional<PerfRecord> perfData =
             perfEngine.getPerfData(perfKey);
         // If no record and disable tuning, run with the default argument
@@ -66,7 +66,7 @@ double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
         auto kernelAttrs =
             KernelAttrs{device, op->getOpType(), DataType::Int32};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
-        auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()};
+        auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         std::optional<PerfRecord> perfData =
             perfEngine.getPerfData(perfKey);
         PerfRecord record;
diff --git a/src/kerels/cpu/matmul.cc b/src/kerels/cpu/matmul.cc
index e8ae5c7e..84fa53a3 100644
--- a/src/kerels/cpu/matmul.cc
+++ b/src/kerels/cpu/matmul.cc
@@ -4,13 +4,14 @@ namespace infini {
 
 template <typename T> class NaiveMatmul : public Kernel {
-    void compute(const Operator &_op) const override {
+    void compute(const Operator &_op, const PerfRecord &record) const override {
         auto op = as<MatmulNode>(_op);
         T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
         T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
         T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
         IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
         IT_ASSERT(op->getAct() == ActType::None);
+        IT_ASSERT(op->getB() == 1);
 
         const int M = op->getM(), N = op->getN(), K = op->getK();
         for (int i = 0; i < M; i++) {
             for (int j = 0; j < N; j++) {
@@ -22,9 +23,7 @@ template <typename T> class NaiveMatmul : public Kernel {
         }
     }
 
-    void compute(const Operator &op, const PerfRecord &record) const override {
-        compute(op);
-    }
+    void compute(const Operator &op) const override { compute(op, {}); }
 
     PerfRecord tune(const Operator &op) const override {
         return PerfRecord{.time = timeit([this, &op]() { compute(op); })};
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 2f9666a2..9f15bc5c 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -49,7 +49,7 @@ HashType MatmulNode::hashWithShape() const {
     return b + m + n + k + transA + transB + enum_to_underlying(act);
 }
 
-OpPerfKey MatmulNode::getOpAttrs() const {
+OpPerfKey MatmulNode::getOpPerfKey() const {
     return OpPerfKey(hashWithShape(), type,
                      {b, m, n, k, transA, transB, enum_to_underlying(act)});
 }