From 90eb9d05a8354279db7c8355f8864a9dd91f6113 Mon Sep 17 00:00:00 2001 From: Anmuliar <1260818462@qq.com> Date: Thu, 22 Sep 2022 15:34:34 +0800 Subject: [PATCH] Json perfrecord (#32) Added perfengine serialization&deserialization and corresponding test case. * Add: perfrecord json representation. * Add: perfrecord virtual func. to_json&from_json. * Add: perfengine serilization and deserilization. * Modify: tune func type to supp derived struct serilization. * Fix: structure after rebase * Chore: Remove empty line in conv.h Co-authored-by: wcz112 Co-authored-by: Liyan Zheng Co-authored-by: zhengly123 --- include/core/data_type.h | 3 ++ include/core/kernel.h | 54 +++++++++++++++++++++++-- include/core/operator.h | 3 ++ include/core/perf_engine.h | 20 ++++++++-- include/core/runtime.h | 1 + include/utils/dataloader.h | 1 - src/core/perf_engine.cc | 52 ++++++++++++++++++++++++ src/core/runtime.cc | 9 ++--- src/cuda/cuda_runtime.cc | 9 +++-- src/kernels/cuda/conv.cc | 46 +++++++++++++++------- src/kernels/cuda/matmul.cc | 41 ++++++++++++------- test/kernels/cuda/test_perfengine.cc | 59 ++++++++++++++++++++++++++++ 12 files changed, 251 insertions(+), 47 deletions(-) create mode 100644 src/core/perf_engine.cc create mode 100644 test/kernels/cuda/test_perfengine.cc diff --git a/include/core/data_type.h b/include/core/data_type.h index 173600fb..654ce1ce 100644 --- a/include/core/data_type.h +++ b/include/core/data_type.h @@ -13,6 +13,9 @@ class DataType { int index; public: + // FIXME: default ctor should be deleted but json requires it. Solution: + // https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types + DataType() = default; constexpr DataType(int index) : index(index) {} bool operator==(const DataType &rhs) const { return index == rhs.index; } bool operator<(const DataType &rhs) const { return index < rhs.index; } diff --git a/include/core/kernel.h b/include/core/kernel.h index 391e57e2..9415d496 100644 --- a/include/core/kernel.h +++ b/include/core/kernel.h @@ -2,7 +2,9 @@ #include "core/common.h" #include "core/operator.h" #include "core/tensor.h" - +#include +#include +using json = nlohmann::json; namespace infini { class RuntimeObj; // Forward declaration for Kernel::compute @@ -10,12 +12,19 @@ class RuntimeObj; // Forward declaration for Kernel::compute struct PerfRecordObj { PerfRecordObj(){}; PerfRecordObj(double time) : time(time){}; - virtual ~PerfRecordObj() {} - + virtual ~PerfRecordObj(){}; double time = 0; // in milliseconds + virtual void to_json(json &j) { + j["type"] = 0; + j["data"] = time; + } + static Ref from_json(const json &j) { + PerfRecordObj tmp; + tmp.time = j["data"].get(); + return make_ref(tmp); + } }; using PerfRecord = Ref; - class Kernel { public: Kernel() {} @@ -39,6 +48,33 @@ class Kernel { const RuntimeObj *context) const = 0; }; +class PerfRecordRegistry { + + private: + std::map> perfrecords; + int nperfrecord = 0; + + public: + ~PerfRecordRegistry() = default; + static PerfRecordRegistry &getInstance() { + static PerfRecordRegistry instance; + return instance; + } + bool + registerPerfRecord(const int type, + std::function constructor) { + IT_ASSERT(perfrecords.find(type) == perfrecords.end(), + "Constructor already registered"); + perfrecords.emplace(type, constructor); + nperfrecord++; + return true; + } + const std::function & + getConstructor(const int type) const { + return perfrecords.at(type); + } +}; + class KernelRegistry { public: using KernelRecord = @@ -100,3 +136,13 @@ class CpuKernelWithoutConfig : public Kernel { #define REGISTER_KERNEL(device, opType, dataType, kernel, name) \ _REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__) + +#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \ + namespace infini { \ + static const bool _CAT(_register_constructor_, cnt) = \ + PerfRecordRegistry::getInstance().registerPerfRecord(type, \ + constructor); \ + } + +#define REGISTER_CONSTRUCTOR(type, constructor) \ + _REGISTER_CONSTRUCTOR_1(type, constructor, __COUNTER__) diff --git a/include/core/operator.h b/include/core/operator.h index 8e1dd593..b2db95ce 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -107,6 +107,9 @@ struct OpPerfKey { vector attrs; public: + // FIXME: default ctor should be deleted but json requires it. Solution: + // https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types + OpPerfKey() = default; OpPerfKey(HashType hash, OpType opType, vector attrs = {}) : hash(hash), opType(opType), attrs(attrs) {} bool operator==(const OpPerfKey &rhs) const { diff --git a/include/core/perf_engine.h b/include/core/perf_engine.h index 4689cbcf..58659134 100644 --- a/include/core/perf_engine.h +++ b/include/core/perf_engine.h @@ -1,7 +1,8 @@ #pragma once #include "core/graph.h" #include "core/kernel.h" - +#include +using json = nlohmann::json; namespace infini { class PerfEngine { @@ -23,18 +24,29 @@ class PerfEngine { return instance; } - std::optional getPerfData(const Key &key) { + /** + * @brief Get the Perf Data object + * + * @return PerfRecord nullptr if no record is fnoud. + */ + PerfRecord getPerfData(const Key &key) { auto it = data.find(key); if (it != data.end()) // find previous evaluating results return data.at(key); else - return std::nullopt; + return nullptr; } - void setPerfData(const Key &key, const PerfRecord &record) { + void setPerfData(const Key &key, PerfRecord record) { IT_ASSERT(data.find(key) == data.end(), "Perf data already exist"); data.emplace(key, record); } + map get_data() { return data; } + void set_data(map data) { this->data = data; } + void savePerfEngineData(std::string file_path); + void loadPerfEngineData(std::string file_path); }; +void to_json(json &j, const PerfEngine &p); +void from_json(const json &j, PerfEngine &p); } // namespace infini \ No newline at end of file diff --git a/include/core/runtime.h b/include/core/runtime.h index 9df2f9c3..c73133c3 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -2,6 +2,7 @@ #include "core/common.h" #include "core/ref.h" #include + namespace infini { /***************** Forward declaration begin *****************/ diff --git a/include/utils/dataloader.h b/include/utils/dataloader.h index 8b32d93a..58e69836 100644 --- a/include/utils/dataloader.h +++ b/include/utils/dataloader.h @@ -6,5 +6,4 @@ namespace infini { void loadTensorData(TensorObj *tensor, std::string file_path); void saveTensorData(TensorObj *tensor, std::string file_path); - } // namespace infini diff --git a/src/core/perf_engine.cc b/src/core/perf_engine.cc new file mode 100644 index 00000000..ecf97f66 --- /dev/null +++ b/src/core/perf_engine.cc @@ -0,0 +1,52 @@ +#include "core/perf_engine.h" +#include +namespace infini { + +REGISTER_CONSTRUCTOR(0, PerfRecordObj::from_json); + +void PerfEngine::savePerfEngineData(std::string file_path) { + std::ofstream fileout(file_path, + std::ios::out | std::ios::trunc | std::ios::binary); + json t = this->getInstance(); + fileout << t << std::endl; + fileout.close(); +} + +void PerfEngine::loadPerfEngineData(std::string file_path) { + std::ifstream filein(file_path, std::ios::in | std::ios::binary); + string t; + filein >> t; + json j = json::parse(t); + from_json(j, this->getInstance()); + filein.close(); +} + +/* json register should in the common namespace with corresponding type*/ +void to_json(json &j, const OpPerfKey &p) { + j = json{{"hashType", p.hash}, {"opType", p.opType}, {"attrs", p.attrs}}; +} +void from_json(const json &j, OpPerfKey &p) { + j.at("hashType").get_to(p.hash); + j.at("opType").get_to(p.opType); + j.at("attrs").get_to(p.attrs); +} +void to_json(json &j, const DataType &p) { + j = p.toString() == "Float32" ? 0 : 1; +} +void from_json(const json &j, DataType &p) { p = DataType(j.get()); } +void to_json(json &j, const PerfRecord &p) { p->to_json(j); } +void from_json(const json &j, PerfRecord &p) { + int type = j["type"].get(); + p = PerfRecordRegistry::getInstance().getConstructor(type)(j); +} + +void to_json(json &j, const PerfEngine &p) { + auto &x = p.getInstance(); + j["data"] = x.get_data(); +} +void from_json(const json &j, PerfEngine &p) { + auto tmp = j["data"].get>(); + p.set_data(tmp); +} + +} // namespace infini \ No newline at end of file diff --git a/src/core/runtime.cc b/src/core/runtime.cc index faf3d0e6..4d02f5ba 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -4,7 +4,6 @@ #include "core/perf_engine.h" #include #include - namespace infini { void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { @@ -21,7 +20,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; - std::optional perfData = perfEngine.getPerfData(perfKey); + auto perfData = perfEngine.getPerfData(perfKey); // If no record and disable tuning, run with the default argument if (!perfData && !tune) { @@ -38,7 +37,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { record = kernel->tune(op, this); perfEngine.setPerfData(perfKey, record); } else - record = *perfData; + record = perfData; if (!profiling) { kernel->compute(op, record, this); @@ -69,7 +68,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const { auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; - std::optional perfData = perfEngine.getPerfData(perfKey); + auto perfData = perfEngine.getPerfData(perfKey); PerfRecord record; // Tune the kernel if there is no record @@ -77,7 +76,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const { record = kernel->tune(op, this); perfEngine.setPerfData(perfKey, record); } else - record = *perfData; + record = perfData; double t = record->time; totalTime += t; diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index fbfdbbbe..e8369535 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -1,7 +1,8 @@ #include "cuda/cuda_runtime.h" #include "core/kernel.h" #include "core/perf_engine.h" - +#include "operators/conv.h" +#include "operators/matmul.h" namespace infini { void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, @@ -17,7 +18,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, KernelAttrs{device, op->getOpType(), DataType::Float32}; Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; - std::optional perfData = perfEngine.getPerfData(perfKey); + auto perfData = perfEngine.getPerfData(perfKey); if (!perfData && !tune) { kernel->compute(op, this); continue; @@ -28,10 +29,10 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, record = kernel->tune(op, this); perfEngine.setPerfData(perfKey, record); } else - record = *perfData; - + record = perfData; double t = record->time; totalTime += t; + json j; if (profiling) { double t = timeit([&]() { kernel->compute(op, record, this); }, diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc index f09d10f4..e80a0b1e 100644 --- a/src/kernels/cuda/conv.cc +++ b/src/kernels/cuda/conv.cc @@ -7,30 +7,47 @@ #include namespace infini { -static constexpr int N_ALGO = 8; -static constexpr cudnnConvolutionFwdAlgo_t ALGOS[N_ALGO] = { - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, - CUDNN_CONVOLUTION_FWD_ALGO_FFT, - CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; -static constexpr int N_MODE = 2; -static constexpr cudnnConvolutionMode_t MODES[N_MODE] = { - CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION}; - struct ConvCuDnnPerfRecordObj : public PerfRecordObj { int algo = 0; // cudnnConvolutionFwdAlgo_t int mode = 1; size_t workspaceSize = 100000; bool fuseAct = false; + void to_json(json &j) override { + j["type"] = 1; + j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize); + } + static PerfRecord from_json(const json &j) { + ConvCuDnnPerfRecordObj tmp; + auto [Algo, Mode, FuseAct, Time, WorkspaceSize] = + j["data"].get>(); + tmp.algo = Algo; + tmp.mode = Mode; + tmp.fuseAct = FuseAct; + tmp.time = Time; + tmp.workspaceSize = WorkspaceSize; + return make_ref(tmp); + } }; + using ConvCuDnnPerfRecord = Ref; class convCudnn : public Kernel { + static constexpr int N_ALGO = 8; + static constexpr int N_MODE = 2; + static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = { + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; + + static constexpr cudnnConvolutionMode_t MODES[2] = { + CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION}; + std::tuple -#include namespace infini { -struct MatmulCudnnPerfRecordObj : public PerfRecordObj { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + +struct MatmulCublasPerfRecordObj : public PerfRecordObj { + int algo = CUBLAS_GEMM_DEFAULT; + void to_json(json &j) override { + j["type"] = 2; + j["data"] = std::make_pair(algo, time); + } + static PerfRecord from_json(const json &j) { + MatmulCublasPerfRecordObj tmp; + auto pr = j["data"].get>(); + tmp.algo = pr.first; + tmp.time = pr.second; + return make_ref(tmp); + } }; -using MatmulCudnnPerfRecord = Ref; + constexpr int N_ALGO = 24; constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = { CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2, @@ -20,7 +30,6 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = { CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20, CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23, }; - class matmulCublas : public Kernel { bool do_compute(const Operator &_op, const PerfRecord &_record, const RuntimeObj *_context) const { @@ -29,7 +38,7 @@ class matmulCublas : public Kernel { void *const inAData = (op->getInputs(0)->getRawDataPtr()); void *const inBData = (op->getInputs(1)->getRawDataPtr()); void *const outData = (op->getOutput()->getRawDataPtr()); - auto record = as(_record); + auto record = as(_record); const auto [b, m, n, k] = op->getBMNK(); auto opA = @@ -44,12 +53,13 @@ class matmulCublas : public Kernel { stat = cublasGemmStridedBatchedEx( context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData, CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta, - outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record->algo); + outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, + (cublasGemmAlgo_t)record->algo); } else { - stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k, - &alpha, inBData, CUDA_R_32F, ldb, inAData, - CUDA_R_32F, lda, &beta, outData, CUDA_R_32F, - ldc, CUDA_R_32F, record->algo); + stat = cublasGemmEx( + context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData, + CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData, + CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo); } return (stat == CUBLAS_STATUS_SUCCESS); } @@ -61,7 +71,7 @@ class matmulCublas : public Kernel { void compute(const Operator &op, const RuntimeObj *context) const override { auto record = - make_ref(); // use default record; + make_ref(); // use default record; compute(op, record, context); } @@ -69,10 +79,10 @@ class matmulCublas : public Kernel { const RuntimeObj *_context) const override { auto context = dynamic_cast(_context); auto op = as(_op); - auto ret = make_ref(); + auto ret = make_ref(); ret->time = std::numeric_limits::max(); for (int i = 0; i < N_ALGO; i++) { - auto rcd = make_ref(); + auto rcd = make_ref(); rcd->algo = ALGOS[i]; if (!do_compute(_op, rcd, _context)) continue; @@ -91,4 +101,5 @@ class matmulCublas : public Kernel { REGISTER_KERNEL(Device::CUDA, OpType::Matmul, DataType::Float32, matmulCublas, "Matmul_cuBLAS_CUDA_Float32"); +REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json); }; // namespace infini \ No newline at end of file diff --git a/test/kernels/cuda/test_perfengine.cc b/test/kernels/cuda/test_perfengine.cc new file mode 100644 index 00000000..99559734 --- /dev/null +++ b/test/kernels/cuda/test_perfengine.cc @@ -0,0 +1,59 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/perf_engine.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/conv.h" +#include "operators/matmul.h" +#include "test.h" + +namespace infini { + +TEST(PerfEngine, save_and_load) { + Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = + gCuda->addOp(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1); + + auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32); + auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32); + gCpu->dataMalloc(); + ACpu->setData(IncrementalGenerator()); + BCpu->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + + auto ACuda = gCuda->cloneTensor(ACpu); + auto BCuda = gCuda->cloneTensor(BCpu); + auto matmul = gCuda->addOp(ACuda, BCuda, nullptr); + + gCuda->dataMalloc(); + cudaRuntime->run(gCuda, true); + auto &perfEngine = PerfEngine::getInstance(); + + json j0 = perfEngine; + std::cout << "PerfEngine saveed:" << std::endl; + std::cout << j0 << std::endl; + perfEngine.savePerfEngineData("test.json"); + perfEngine.loadPerfEngineData("test.json"); + json j1 = perfEngine; + std::cout << "PerfEngine loaded:" << std::endl; + std::cout << j1 << std::endl; + EXPECT_TRUE(j0 == j1); +} +} // namespace infini \ No newline at end of file