diff --git a/include/core/common.h b/include/core/common.h index 365b40f4..8dd6a108 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -12,7 +12,7 @@ #include #include -namespace it { +namespace infini { using std::list; using std::map; using std::nullopt; @@ -46,8 +46,8 @@ using dtype = float; #define _IT_ASSERT_1(name) _IT_ASSERT_2(name, ""); #define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__) -#define IT_TODO_HALT(...) IT_ASSERT(false, "Unimplemented") -#define IT_TODO_SKIP(...) puts("Unimplemented " __FILE__ ":" __LINE__) +#define IT_TODO_HALT() IT_ASSERT(false, "Unimplemented") +#define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__) // Other utilities @@ -56,4 +56,6 @@ template auto enum_to_underlying(T e) { return static_cast>(e); } -} // namespace it +double timeit(const std::function &func); + +} // namespace infini diff --git a/include/core/graph.h b/include/core/graph.h index 64db2330..2f72bb94 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -2,7 +2,7 @@ #include "core/operator.h" #include "core/tensor.h" -namespace it { +namespace infini { class GraphNode : public Object { protected: @@ -32,12 +32,9 @@ class GraphNode : public Object { void updateConnection(); void dataMalloc(); - // TODO - // bool compute(); - // TODO: move to another class // bool exportOnnx(const char *path); // bool importOnnx(const char *net); }; -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/include/core/kernel.h b/include/core/kernel.h index 74b01168..437884e2 100644 --- a/include/core/kernel.h +++ b/include/core/kernel.h @@ -3,22 +3,26 @@ #include "core/operator.h" #include "core/tensor.h" -namespace it { +namespace infini { -enum class Device { CPU = 1, CUDA }; +struct PerfRecord { + double time; // in milliseconds +}; class Kernel { public: Kernel() {} virtual ~Kernel() {} + virtual void compute(const Operator &op, + const PerfRecord &record) const = 0; + // This function call compute with a default record. virtual void compute(const Operator &op) const = 0; + // Tuning should be idempotent since it is called multiple times. + virtual PerfRecord tune(const Operator &op) const = 0; }; class KernelRegistry { - public: - using Key = std::tuple; - public: ~KernelRegistry() { for (auto &[k, v] : kernels) @@ -28,29 +32,29 @@ class KernelRegistry { static KernelRegistry instance; return instance; } - bool registerKernel(const Key &key, Kernel *kernel) { + bool registerKernel(const KernelAttrs &key, Kernel *kernel) { // TODO: kernels with priority IT_ASSERT(kernels.find(key) == kernels.end(), "Kernel already registered"); kernels.emplace(key, kernel); return true; } - Kernel *getKernel(Device device, OpType opType, DataType dataType) const { - return kernels.at(Key{device, opType, dataType}); + Kernel *getKernel(const KernelAttrs &kernelAttrs) const { + return kernels.at(kernelAttrs); } private: - std::map kernels; + std::map kernels; }; #define _REGISTER_KERNEL_1(device, opType, dataType, kernel, cnt) \ - namespace it { \ + namespace infini { \ static const bool _CAT(_register_kernel_, cnt) = \ KernelRegistry::getInstance().registerKernel( \ - KernelRegistry::Key{device, opType, dataType}, new kernel()); \ + KernelAttrs{device, opType, dataType}, new kernel()); \ } #define REGISTER_KERNEL(device, opType, dataType, kernel) \ _REGISTER_KERNEL_1(device, opType, dataType, kernel, __COUNTER__) -} // namespace it +} // namespace infini diff --git a/include/core/object.h b/include/core/object.h index 10a0d46d..0faec9f5 100644 --- a/include/core/object.h +++ b/include/core/object.h @@ -2,7 +2,7 @@ #include "core/common.h" #include "ref.h" -namespace it { +namespace infini { using GuidBaseType = int; @@ -51,4 +51,4 @@ inline std::ostream &operator<<(std::ostream &os, const Ref &obj) { return os; } -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/include/core/operator.h b/include/core/operator.h index 18b8e842..c8146378 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -1,7 +1,7 @@ #pragma once #include "core/tensor.h" -namespace it { +namespace infini { enum class OpType { Unknown = 0, @@ -37,9 +37,13 @@ enum class OpType { MemBound = 300, }; +enum class Device { CPU = 1, CUDA }; + +using KernelAttrs = std::tuple; + class OpRegistry { public: - std::string getOpName(OpType opType) { + static std::string getOpName(OpType opType) { #define FOP(op) \ case OpType::op: \ return #op @@ -90,6 +94,16 @@ enum class ActType { Tanh, }; +struct OpAttrs { + public: + virtual bool operator<(const OpAttrs &rhs) const { + IT_ASSERT(typeid(*this) == typeid(rhs), "OpAttrs type mismatch."); + // Empty OpAttrs are equal + return false; + } + virtual ~OpAttrs() {} +}; + class OperatorNode : public Object { public: protected: @@ -103,6 +117,7 @@ class OperatorNode : public Object { OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs) : type(opType), inputs(inputs), outputs(outputs) {} virtual vector computeShape() const = 0; + virtual OpAttrs getOpAttrs() const = 0; public: // check Op type bool isLinearOp() const; @@ -132,13 +147,25 @@ class OperatorNode : public Object { class MatmulNode : public OperatorNode { public: - struct MatmulArgs { + struct MatmulArgs : public OpAttrs { int b, m, n, k; // PET assume a row-major tensor layout. transA=false means default // dims, true means A should be transposed before matmul. This is in // oppsite to column-major BLAS. bool transA, transB; ActType act; + + MatmulArgs(int b, int m, int n, int k, bool transA, bool transB, + ActType act) + : b(b), m(m), n(n), k(k), transA(transA), transB(transB), act(act) { + } + + bool operator<(const OpAttrs &rhsGeneric) { + auto rhs = dynamic_cast(rhsGeneric); + return std::tie(b, m, n, k, transA, transB, act) < + std::tie(rhs.b, rhs.m, rhs.n, rhs.k, rhs.transA, rhs.transB, + rhs.act); + } }; private: @@ -162,6 +189,7 @@ class MatmulNode : public OperatorNode { bool getTransB() const { return args.transB; } MatmulArgs getArgs() const { return args; } + OpAttrs getOpAttrs() const override { return args; } private: // Q: whether to check the output? Since we can build an Op first and then @@ -170,4 +198,4 @@ class MatmulNode : public OperatorNode { bool checkValid(const TensorVec &inputs) const; }; -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/include/core/perf_engine.h b/include/core/perf_engine.h new file mode 100644 index 00000000..b55baf26 --- /dev/null +++ b/include/core/perf_engine.h @@ -0,0 +1,34 @@ +#pragma once +#include "core/graph.h" +#include "core/kernel.h" + +namespace infini { + +class PerfEngine { + public: + using Key = std::pair; + + private: + map data; + + public: + static PerfEngine &getInstance() { + static PerfEngine instance; + return instance; + } + + std::optional getPerfData(const Key &key) { + auto it = data.find(key); + if (it != data.end()) // find previous evaluating results + return data.at(key); + else + return std::nullopt; + } + + void setPerfData(const Key &key, const PerfRecord &record) { + IT_ASSERT(data.find(key) == data.end(), "Perf data already exist"); + data.emplace(key, record); + } +}; + +} // namespace infini \ No newline at end of file diff --git a/include/core/ref.h b/include/core/ref.h index 34546e02..f5ba4e89 100644 --- a/include/core/ref.h +++ b/include/core/ref.h @@ -3,7 +3,7 @@ #include #include -namespace it { +namespace infini { template using Ref = std::shared_ptr; template using WRef = std::weak_ptr; @@ -32,4 +32,4 @@ std::vector> get_wref_vec(const std::vector> &vec) { return wref_vec; } -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/include/core/run_enigne.h b/include/core/run_enigne.h index ee8ad644..1b4877a1 100644 --- a/include/core/run_enigne.h +++ b/include/core/run_enigne.h @@ -1,24 +1,26 @@ +#pragma once #include "core/graph.h" #include "core/kernel.h" +#include "core/perf_engine.h" -namespace it { +namespace infini { class RunEngine { + private: + Device device; + public: RunEngine(Device device) : device(device) {} ~RunEngine() {} - void run(Graph graph) const { - const auto &kernelRegistry = KernelRegistry::getInstance(); - for (auto &op : graph->getOperators()) { - // HACK: set correct data type - Kernel *kernel = kernelRegistry.getKernel(device, op->getOpType(), - DataType::Int32); - kernel->compute(op); - } - } + void run(const Graph &graph, bool tune = false, + bool profiling = false) const; + double getPerfTime(const Graph &graph, bool profiling = false) const; private: - Device device; + void printProfilingData(double totTime, + const std::map &opTime, + const std::map &opCnt) const; }; -} // namespace it \ No newline at end of file + +} // namespace infini \ No newline at end of file diff --git a/include/core/tensor.h b/include/core/tensor.h index eb9117ab..67544753 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -1,7 +1,7 @@ #pragma once #include "core/tensor_base.h" -namespace it { +namespace infini { // TODO: how to deal with this using ShapeElem = int; @@ -177,4 +177,4 @@ class TensorNode : public TensorBaseNode { // void printShape(); }; -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/include/core/tensor_base.h b/include/core/tensor_base.h index c8a5d8dd..da08e118 100644 --- a/include/core/tensor_base.h +++ b/include/core/tensor_base.h @@ -2,7 +2,7 @@ #include "core/object.h" #include "core/ref.h" -namespace it { +namespace infini { // class Tensor; class TensorBaseNode; @@ -267,4 +267,4 @@ class TensorBaseNode : public Object { // void printShape(); }; -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/src/core/common.cc b/src/core/common.cc new file mode 100644 index 00000000..d1c7fd40 --- /dev/null +++ b/src/core/common.cc @@ -0,0 +1,14 @@ +#include "core/common.h" +#include +#include + +namespace infini { + +double timeit(const std::function &func) { + auto start = std::chrono::high_resolution_clock::now(); + func(); + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end - start).count(); +} + +} // namespace infini \ No newline at end of file diff --git a/src/core/graph.cc b/src/core/graph.cc index 453cba7e..0f6fb180 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,6 +1,6 @@ #include "core/graph.h" -namespace it { +namespace infini { void GraphNode::updateConnection() { IT_TODO_HALT(); } @@ -17,4 +17,4 @@ void GraphNode::dataMalloc() { tensor->dataMalloc(); } -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/src/core/operator.cc b/src/core/operator.cc index 122a7dfd..7a9dd38e 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -1,6 +1,6 @@ #include "core/operator.h" -namespace it { +namespace infini { bool OperatorNode::isLinearOp() const { return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200; @@ -37,13 +37,9 @@ vector MatmulNode::computeShape() const { MatmulNode::MatmulNode(Tensor A, Tensor B, Tensor C, bool transA, bool transB, Tensor bias, ActType act) : OperatorNode(OpType::Matmul, {A, B, bias}, {C}), - args{.b = A->getDims()[0], - .m = transA ? A->getDims()[2] : A->getDims()[1], - .n = transB ? B->getDims()[1] : B->getDims()[2], - .k = transA ? A->getDims()[1] : A->getDims()[2], - .transA = transA, - .transB = transB, - .act = act} { + args(A->getDims()[0], transA ? A->getDims()[2] : A->getDims()[1], + transB ? B->getDims()[1] : B->getDims()[2], + transA ? A->getDims()[1] : A->getDims()[2], transA, transB, act) { IT_ASSERT(checkValid(inputs)); } @@ -78,4 +74,4 @@ bool MatmulNode::checkValid(const TensorVec &inputs) const { return true; } -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/src/core/run_engine.cc b/src/core/run_engine.cc new file mode 100644 index 00000000..155738e1 --- /dev/null +++ b/src/core/run_engine.cc @@ -0,0 +1,105 @@ +#include "core/run_enigne.h" +#include + +namespace infini { + +void RunEngine::run(const Graph &graph, bool tune, bool profiling) const { + if (!tune && profiling) + IT_TODO_HALT(); + const auto &kernelRegistry = KernelRegistry::getInstance(); + auto perfEngine = PerfEngine::getInstance(); + // Statistics + double totalTime = 0; + std::map opTime; + std::map opCnt; + std::chrono::system_clock::time_point begin, end; + + for (auto &op : graph->getOperators()) { + // HACK: set correct data type + auto kernelAttrs = + KernelAttrs{device, op->getOpType(), DataType::Int32}; + Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); + auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()}; + std::optional perfData = perfEngine.getPerfData(perfKey); + + // If no record and disable tuning, run with the default argument + if (!perfData && !tune) { + kernel->compute(op); + continue; + } + + // TODO: The copy of record should be eliminated + PerfRecord record; + // Tune the kernel if there is no record + if (!perfData) { + record = kernel->tune(op); + perfEngine.setPerfData(perfKey, record); + } else + record = *perfData; + + if (!profiling) { + kernel->compute(op, *perfData); + continue; + } else { + double t = timeit([&]() { kernel->compute(op, *perfData); }); + op->print(); + printf(" op_time %lf\n", t); + totalTime += t; + opTime[op->getOpType()] += t; + opCnt[op->getOpType()]++; + } + } + if (profiling) + printProfilingData(totalTime, opTime, opCnt); +} + +double RunEngine::getPerfTime(const Graph &graph, bool profiling) const { + const auto &kernelRegistry = KernelRegistry::getInstance(); + auto perfEngine = PerfEngine::getInstance(); + // Statistics + double totalTime = 0; + std::map opTime; + std::map opCnt; + + for (auto &op : graph->getOperators()) { + // HACK: set correct data type + auto kernelAttrs = + KernelAttrs{device, op->getOpType(), DataType::Int32}; + Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); + auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()}; + std::optional perfData = perfEngine.getPerfData(perfKey); + + PerfRecord record; + // Tune the kernel if there is no record + if (!perfData) { + record = kernel->tune(op); + perfEngine.setPerfData(perfKey, record); + } else + record = *perfData; + + double t = record.time; + totalTime += t; + if (profiling) { + op->print(); + printf(" op_time %lf\n", t); + opTime[op->getOpType()] += t; + opCnt[op->getOpType()]++; + } + } + if (profiling) + printProfilingData(totalTime, opTime, opCnt); + return totalTime; +} + +void RunEngine::printProfilingData(double totalTime, + const std::map &opTime, + const std::map &opCnt) const { + printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean"); + for (const auto &[type, t] : opTime) { + printf("%11s %3d %7.3f %7.1f %7.3f\n", + OpRegistry::getOpName(type).data(), opCnt.at(type), t, + t / totalTime * 100, t / opCnt.at(type)); + } +} + +} // namespace infini \ No newline at end of file diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 877e0396..18460986 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -1,5 +1,5 @@ #include -namespace it { +namespace infini { TensorNode::TensorNode(const Shape &shape, DataType dtype) : TensorBaseNode(shape.size(), dtype), shape(shape) {} @@ -87,4 +87,4 @@ bool TensorNode::equalData(const Tensor &rhs) const { return true; } -}; // namespace it \ No newline at end of file +}; // namespace infini \ No newline at end of file diff --git a/src/core/tensor_base.cc b/src/core/tensor_base.cc index 3275a84f..72297ce0 100644 --- a/src/core/tensor_base.cc +++ b/src/core/tensor_base.cc @@ -1,9 +1,9 @@ #include -namespace it { +namespace infini { TensorBaseNode::TensorBaseNode(int dim, DataType dtype) : dim(dim), dtype(dtype) {} VType TensorBaseNode::getData(size_t offset) const { return data[offset]; } -}; // namespace it \ No newline at end of file +}; // namespace infini \ No newline at end of file diff --git a/src/kerels/cpu/matmul.cc b/src/kerels/cpu/matmul.cc index 7f53dfb3..bc850730 100644 --- a/src/kerels/cpu/matmul.cc +++ b/src/kerels/cpu/matmul.cc @@ -1,6 +1,6 @@ #include "core/kernel.h" -namespace it { +namespace infini { template class NaiveMatmul : public Kernel { void compute(const Operator &_op) const override { @@ -14,12 +14,21 @@ template class NaiveMatmul : public Kernel { const int M = args.m, N = args.n, K = args.k; for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { + C[i * N + j] = 0; for (int k = 0; k < K; k++) { C[i * N + j] += A[i * K + k] * B[k * N + j]; } } } } + + void compute(const Operator &op, const PerfRecord &record) const override { + compute(op); + } + + PerfRecord tune(const Operator &op) const override { + return PerfRecord{.time = timeit([this, &op]() { compute(op); })}; + } }; REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32, @@ -27,4 +36,4 @@ REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32, REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32, NaiveMatmul); -} // namespace it \ No newline at end of file +} // namespace infini \ No newline at end of file diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index 206a7a7f..067b83c7 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -2,9 +2,9 @@ #include "core/run_enigne.h" #include "test.h" -namespace it { +namespace infini { -TEST(Graph, build) { +TEST(Graph, build_and_run) { Graph g = make_ref(); Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); @@ -21,4 +21,25 @@ TEST(Graph, build) { EXPECT_TRUE(o0->equalData(ans)); } -} // namespace it \ No newline at end of file +TEST(Graph, perf_engine) { + Graph g = make_ref(); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32); + Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32); + g->dataMalloc(); + i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); + w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data()); + g->addOp(make_ref(i0, w0, o0)); + RunEngine(Device::CPU).run(g, true, true); + double perfTime = RunEngine(Device::CPU).getPerfTime(g); + // The example matmul takes 0.0036ms with one core + EXPECT_GT(perfTime, 0); + EXPECT_LT(perfTime, 0.01); + // check answer + auto ans = make_ref(Shape{1, 2, 4}, DataType::Int32); + ans->dataMalloc(); + ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}.data()); + EXPECT_TRUE(o0->equalData(ans)); +} + +} // namespace infini \ No newline at end of file