forked from jiuyuan/InfiniTensor
Json perfrecord (#32)
Added perfengine serialization&deserialization and corresponding test case. * Add: perfrecord json representation. * Add: perfrecord virtual func. to_json&from_json. * Add: perfengine serilization and deserilization. * Modify: tune func type to supp derived struct serilization. * Fix: structure after rebase * Chore: Remove empty line in conv.h Co-authored-by: wcz112 <wcz19@mails.tsinghua.edu.cn> Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com> Co-authored-by: zhengly123 <zhengly123@outlook.com>
This commit is contained in:
parent
9032cbb973
commit
90eb9d05a8
|
@ -13,6 +13,9 @@ class DataType {
|
|||
int index;
|
||||
|
||||
public:
|
||||
// FIXME: default ctor should be deleted but json requires it. Solution:
|
||||
// https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types
|
||||
DataType() = default;
|
||||
constexpr DataType(int index) : index(index) {}
|
||||
bool operator==(const DataType &rhs) const { return index == rhs.index; }
|
||||
bool operator<(const DataType &rhs) const { return index < rhs.index; }
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
#include "core/common.h"
|
||||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
|
||||
#include <functional>
|
||||
#include <nlohmann/json.hpp>
|
||||
using json = nlohmann::json;
|
||||
namespace infini {
|
||||
|
||||
class RuntimeObj; // Forward declaration for Kernel::compute
|
||||
|
@ -10,12 +12,19 @@ class RuntimeObj; // Forward declaration for Kernel::compute
|
|||
struct PerfRecordObj {
|
||||
PerfRecordObj(){};
|
||||
PerfRecordObj(double time) : time(time){};
|
||||
virtual ~PerfRecordObj() {}
|
||||
|
||||
virtual ~PerfRecordObj(){};
|
||||
double time = 0; // in milliseconds
|
||||
virtual void to_json(json &j) {
|
||||
j["type"] = 0;
|
||||
j["data"] = time;
|
||||
}
|
||||
static Ref<PerfRecordObj> from_json(const json &j) {
|
||||
PerfRecordObj tmp;
|
||||
tmp.time = j["data"].get<int>();
|
||||
return make_ref<PerfRecordObj>(tmp);
|
||||
}
|
||||
};
|
||||
using PerfRecord = Ref<PerfRecordObj>;
|
||||
|
||||
class Kernel {
|
||||
public:
|
||||
Kernel() {}
|
||||
|
@ -39,6 +48,33 @@ class Kernel {
|
|||
const RuntimeObj *context) const = 0;
|
||||
};
|
||||
|
||||
class PerfRecordRegistry {
|
||||
|
||||
private:
|
||||
std::map<int, std::function<PerfRecord(const json &)>> perfrecords;
|
||||
int nperfrecord = 0;
|
||||
|
||||
public:
|
||||
~PerfRecordRegistry() = default;
|
||||
static PerfRecordRegistry &getInstance() {
|
||||
static PerfRecordRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
bool
|
||||
registerPerfRecord(const int type,
|
||||
std::function<PerfRecord(const json &)> constructor) {
|
||||
IT_ASSERT(perfrecords.find(type) == perfrecords.end(),
|
||||
"Constructor already registered");
|
||||
perfrecords.emplace(type, constructor);
|
||||
nperfrecord++;
|
||||
return true;
|
||||
}
|
||||
const std::function<PerfRecord(const json &)> &
|
||||
getConstructor(const int type) const {
|
||||
return perfrecords.at(type);
|
||||
}
|
||||
};
|
||||
|
||||
class KernelRegistry {
|
||||
public:
|
||||
using KernelRecord =
|
||||
|
@ -100,3 +136,13 @@ class CpuKernelWithoutConfig : public Kernel {
|
|||
|
||||
#define REGISTER_KERNEL(device, opType, dataType, kernel, name) \
|
||||
_REGISTER_KERNEL_1(device, opType, dataType, kernel, name, __COUNTER__)
|
||||
|
||||
#define _REGISTER_CONSTRUCTOR_1(type, constructor, cnt) \
|
||||
namespace infini { \
|
||||
static const bool _CAT(_register_constructor_, cnt) = \
|
||||
PerfRecordRegistry::getInstance().registerPerfRecord(type, \
|
||||
constructor); \
|
||||
}
|
||||
|
||||
#define REGISTER_CONSTRUCTOR(type, constructor) \
|
||||
_REGISTER_CONSTRUCTOR_1(type, constructor, __COUNTER__)
|
||||
|
|
|
@ -107,6 +107,9 @@ struct OpPerfKey {
|
|||
vector<int> attrs;
|
||||
|
||||
public:
|
||||
// FIXME: default ctor should be deleted but json requires it. Solution:
|
||||
// https://github.com/nlohmann/json#how-can-i-use-get-for-non-default-constructiblenon-copyable-types
|
||||
OpPerfKey() = default;
|
||||
OpPerfKey(HashType hash, OpType opType, vector<int> attrs = {})
|
||||
: hash(hash), opType(opType), attrs(attrs) {}
|
||||
bool operator==(const OpPerfKey &rhs) const {
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
#pragma once
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
using json = nlohmann::json;
|
||||
namespace infini {
|
||||
|
||||
class PerfEngine {
|
||||
|
@ -23,18 +24,29 @@ class PerfEngine {
|
|||
return instance;
|
||||
}
|
||||
|
||||
std::optional<PerfRecord> getPerfData(const Key &key) {
|
||||
/**
|
||||
* @brief Get the Perf Data object
|
||||
*
|
||||
* @return PerfRecord nullptr if no record is fnoud.
|
||||
*/
|
||||
PerfRecord getPerfData(const Key &key) {
|
||||
auto it = data.find(key);
|
||||
if (it != data.end()) // find previous evaluating results
|
||||
return data.at(key);
|
||||
else
|
||||
return std::nullopt;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void setPerfData(const Key &key, const PerfRecord &record) {
|
||||
void setPerfData(const Key &key, PerfRecord record) {
|
||||
IT_ASSERT(data.find(key) == data.end(), "Perf data already exist");
|
||||
data.emplace(key, record);
|
||||
}
|
||||
map<Key, PerfRecord> get_data() { return data; }
|
||||
void set_data(map<Key, PerfRecord> data) { this->data = data; }
|
||||
void savePerfEngineData(std::string file_path);
|
||||
void loadPerfEngineData(std::string file_path);
|
||||
};
|
||||
void to_json(json &j, const PerfEngine &p);
|
||||
void from_json(const json &j, PerfEngine &p);
|
||||
|
||||
} // namespace infini
|
|
@ -2,6 +2,7 @@
|
|||
#include "core/common.h"
|
||||
#include "core/ref.h"
|
||||
#include <memory>
|
||||
|
||||
namespace infini {
|
||||
|
||||
/***************** Forward declaration begin *****************/
|
||||
|
|
|
@ -6,5 +6,4 @@ namespace infini {
|
|||
|
||||
void loadTensorData(TensorObj *tensor, std::string file_path);
|
||||
void saveTensorData(TensorObj *tensor, std::string file_path);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
#include "core/perf_engine.h"
|
||||
#include <fstream>
|
||||
namespace infini {
|
||||
|
||||
REGISTER_CONSTRUCTOR(0, PerfRecordObj::from_json);
|
||||
|
||||
void PerfEngine::savePerfEngineData(std::string file_path) {
|
||||
std::ofstream fileout(file_path,
|
||||
std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
json t = this->getInstance();
|
||||
fileout << t << std::endl;
|
||||
fileout.close();
|
||||
}
|
||||
|
||||
void PerfEngine::loadPerfEngineData(std::string file_path) {
|
||||
std::ifstream filein(file_path, std::ios::in | std::ios::binary);
|
||||
string t;
|
||||
filein >> t;
|
||||
json j = json::parse(t);
|
||||
from_json(j, this->getInstance());
|
||||
filein.close();
|
||||
}
|
||||
|
||||
/* json register should in the common namespace with corresponding type*/
|
||||
void to_json(json &j, const OpPerfKey &p) {
|
||||
j = json{{"hashType", p.hash}, {"opType", p.opType}, {"attrs", p.attrs}};
|
||||
}
|
||||
void from_json(const json &j, OpPerfKey &p) {
|
||||
j.at("hashType").get_to(p.hash);
|
||||
j.at("opType").get_to(p.opType);
|
||||
j.at("attrs").get_to(p.attrs);
|
||||
}
|
||||
void to_json(json &j, const DataType &p) {
|
||||
j = p.toString() == "Float32" ? 0 : 1;
|
||||
}
|
||||
void from_json(const json &j, DataType &p) { p = DataType(j.get<int>()); }
|
||||
void to_json(json &j, const PerfRecord &p) { p->to_json(j); }
|
||||
void from_json(const json &j, PerfRecord &p) {
|
||||
int type = j["type"].get<int>();
|
||||
p = PerfRecordRegistry::getInstance().getConstructor(type)(j);
|
||||
}
|
||||
|
||||
void to_json(json &j, const PerfEngine &p) {
|
||||
auto &x = p.getInstance();
|
||||
j["data"] = x.get_data();
|
||||
}
|
||||
void from_json(const json &j, PerfEngine &p) {
|
||||
auto tmp = j["data"].get<map<PerfEngine::Key, PerfRecord>>();
|
||||
p.set_data(tmp);
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -4,7 +4,6 @@
|
|||
#include "core/perf_engine.h"
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
|
||||
namespace infini {
|
||||
|
||||
void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||
|
@ -21,7 +20,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
||||
// If no record and disable tuning, run with the default argument
|
||||
if (!perfData && !tune) {
|
||||
|
@ -38,7 +37,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
|||
record = kernel->tune(op, this);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = *perfData;
|
||||
record = perfData;
|
||||
|
||||
if (!profiling) {
|
||||
kernel->compute(op, record, this);
|
||||
|
@ -69,7 +68,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
|||
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
|
||||
PerfRecord record;
|
||||
// Tune the kernel if there is no record
|
||||
|
@ -77,7 +76,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
|||
record = kernel->tune(op, this);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = *perfData;
|
||||
record = perfData;
|
||||
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
#include "cuda/cuda_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
namespace infini {
|
||||
|
||||
void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||
|
@ -17,7 +18,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
KernelAttrs{device, op->getOpType(), DataType::Float32};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op, this);
|
||||
continue;
|
||||
|
@ -28,10 +29,10 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
|||
record = kernel->tune(op, this);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = *perfData;
|
||||
|
||||
record = perfData;
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
json j;
|
||||
|
||||
if (profiling) {
|
||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||
|
|
|
@ -7,30 +7,47 @@
|
|||
#include <tuple>
|
||||
namespace infini {
|
||||
|
||||
static constexpr int N_ALGO = 8;
|
||||
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[N_ALGO] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
|
||||
static constexpr int N_MODE = 2;
|
||||
static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
|
||||
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||
|
||||
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
|
||||
int algo = 0; // cudnnConvolutionFwdAlgo_t
|
||||
int mode = 1;
|
||||
size_t workspaceSize = 100000;
|
||||
bool fuseAct = false;
|
||||
void to_json(json &j) override {
|
||||
j["type"] = 1;
|
||||
j["data"] = std::make_tuple(algo, mode, fuseAct, time, workspaceSize);
|
||||
}
|
||||
static PerfRecord from_json(const json &j) {
|
||||
ConvCuDnnPerfRecordObj tmp;
|
||||
auto [Algo, Mode, FuseAct, Time, WorkspaceSize] =
|
||||
j["data"].get<tuple<int, int, bool, double, size_t>>();
|
||||
tmp.algo = Algo;
|
||||
tmp.mode = Mode;
|
||||
tmp.fuseAct = FuseAct;
|
||||
tmp.time = Time;
|
||||
tmp.workspaceSize = WorkspaceSize;
|
||||
return make_ref<ConvCuDnnPerfRecordObj>(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
using ConvCuDnnPerfRecord = Ref<ConvCuDnnPerfRecordObj>;
|
||||
|
||||
class convCudnn : public Kernel {
|
||||
|
||||
static constexpr int N_ALGO = 8;
|
||||
static constexpr int N_MODE = 2;
|
||||
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[8] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
|
||||
|
||||
static constexpr cudnnConvolutionMode_t MODES[2] = {
|
||||
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
|
||||
|
||||
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
|
||||
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
|
||||
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
|
||||
|
@ -276,4 +293,5 @@ class convCudnn : public Kernel {
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn,
|
||||
"Conv_cuDNN_CUDA_Float32");
|
||||
|
||||
REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json);
|
||||
} // namespace infini
|
|
@ -1,14 +1,24 @@
|
|||
#include "operators/matmul.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
|
||||
namespace infini {
|
||||
struct MatmulCudnnPerfRecordObj : public PerfRecordObj {
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
|
||||
struct MatmulCublasPerfRecordObj : public PerfRecordObj {
|
||||
int algo = CUBLAS_GEMM_DEFAULT;
|
||||
void to_json(json &j) override {
|
||||
j["type"] = 2;
|
||||
j["data"] = std::make_pair(algo, time);
|
||||
}
|
||||
static PerfRecord from_json(const json &j) {
|
||||
MatmulCublasPerfRecordObj tmp;
|
||||
auto pr = j["data"].get<pair<int, double>>();
|
||||
tmp.algo = pr.first;
|
||||
tmp.time = pr.second;
|
||||
return make_ref<MatmulCublasPerfRecordObj>(tmp);
|
||||
}
|
||||
};
|
||||
using MatmulCudnnPerfRecord = Ref<MatmulCudnnPerfRecordObj>;
|
||||
|
||||
constexpr int N_ALGO = 24;
|
||||
constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
||||
CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2,
|
||||
|
@ -20,7 +30,6 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
|||
CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
|
||||
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
|
||||
};
|
||||
|
||||
class matmulCublas : public Kernel {
|
||||
bool do_compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const {
|
||||
|
@ -29,7 +38,7 @@ class matmulCublas : public Kernel {
|
|||
void *const inAData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const inBData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto record = as<MatmulCudnnPerfRecordObj>(_record);
|
||||
auto record = as<MatmulCublasPerfRecordObj>(_record);
|
||||
|
||||
const auto [b, m, n, k] = op->getBMNK();
|
||||
auto opA =
|
||||
|
@ -44,12 +53,13 @@ class matmulCublas : public Kernel {
|
|||
stat = cublasGemmStridedBatchedEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||
CUDA_R_32F, ldb, k * n, inAData, CUDA_R_32F, lda, m * k, &beta,
|
||||
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F, record->algo);
|
||||
outData, CUDA_R_32F, ldc, m * n, b, CUDA_R_32F,
|
||||
(cublasGemmAlgo_t)record->algo);
|
||||
} else {
|
||||
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||
&alpha, inBData, CUDA_R_32F, ldb, inAData,
|
||||
CUDA_R_32F, lda, &beta, outData, CUDA_R_32F,
|
||||
ldc, CUDA_R_32F, record->algo);
|
||||
stat = cublasGemmEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha, inBData,
|
||||
CUDA_R_32F, ldb, inAData, CUDA_R_32F, lda, &beta, outData,
|
||||
CUDA_R_32F, ldc, CUDA_R_32F, (cublasGemmAlgo_t)record->algo);
|
||||
}
|
||||
return (stat == CUBLAS_STATUS_SUCCESS);
|
||||
}
|
||||
|
@ -61,7 +71,7 @@ class matmulCublas : public Kernel {
|
|||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
auto record =
|
||||
make_ref<MatmulCudnnPerfRecordObj>(); // use default record;
|
||||
make_ref<MatmulCublasPerfRecordObj>(); // use default record;
|
||||
compute(op, record, context);
|
||||
}
|
||||
|
||||
|
@ -69,10 +79,10 @@ class matmulCublas : public Kernel {
|
|||
const RuntimeObj *_context) const override {
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
auto op = as<MatmulObj>(_op);
|
||||
auto ret = make_ref<MatmulCudnnPerfRecordObj>();
|
||||
auto ret = make_ref<MatmulCublasPerfRecordObj>();
|
||||
ret->time = std::numeric_limits<double>::max();
|
||||
for (int i = 0; i < N_ALGO; i++) {
|
||||
auto rcd = make_ref<MatmulCudnnPerfRecordObj>();
|
||||
auto rcd = make_ref<MatmulCublasPerfRecordObj>();
|
||||
rcd->algo = ALGOS[i];
|
||||
if (!do_compute(_op, rcd, _context))
|
||||
continue;
|
||||
|
@ -91,4 +101,5 @@ class matmulCublas : public Kernel {
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Matmul, DataType::Float32, matmulCublas,
|
||||
"Matmul_cuBLAS_CUDA_Float32");
|
||||
|
||||
REGISTER_CONSTRUCTOR(2, MatmulCublasPerfRecordObj::from_json);
|
||||
}; // namespace infini
|
|
@ -0,0 +1,59 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(PerfEngine, save_and_load) {
|
||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to CUDA
|
||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||
// Build CUDA graph
|
||||
auto conv =
|
||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
||||
|
||||
auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
|
||||
auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
ACpu->setData(IncrementalGenerator());
|
||||
BCpu->setData(IncrementalGenerator());
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
|
||||
|
||||
gCuda->dataMalloc();
|
||||
cudaRuntime->run(gCuda, true);
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
|
||||
json j0 = perfEngine;
|
||||
std::cout << "PerfEngine saveed:" << std::endl;
|
||||
std::cout << j0 << std::endl;
|
||||
perfEngine.savePerfEngineData("test.json");
|
||||
perfEngine.loadPerfEngineData("test.json");
|
||||
json j1 = perfEngine;
|
||||
std::cout << "PerfEngine loaded:" << std::endl;
|
||||
std::cout << j1 << std::endl;
|
||||
EXPECT_TRUE(j0 == j1);
|
||||
}
|
||||
} // namespace infini
|
Loading…
Reference in New Issue