2022-08-22 15:01:03 +08:00
|
|
|
#pragma once
|
|
|
|
#include "core/runtime.h"
|
|
|
|
#include "cuda/cuda_common.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
|
|
|
class CudaRuntimeObj : public RuntimeObj {
|
|
|
|
private:
|
|
|
|
cudnnHandle_t cudnn;
|
|
|
|
cublasHandle_t cublas;
|
|
|
|
CudaPtr workspace;
|
|
|
|
size_t workspaceSize;
|
|
|
|
|
2022-09-22 18:06:45 +08:00
|
|
|
public:
|
|
|
|
CUdevice cuDevice;
|
|
|
|
CUcontext newContext;
|
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
public:
|
|
|
|
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
|
|
|
|
checkCudnnError(cudnnCreate(&cudnn));
|
|
|
|
checkCublasError(cublasCreate(&cublas));
|
|
|
|
// 10GB for Longformer
|
|
|
|
// size_t longformerNum = 3lu * (1 << 30);
|
|
|
|
workspaceSize = 7ll << 30; // 7 GB
|
|
|
|
workspace = alloc(workspaceSize);
|
2022-09-22 18:06:45 +08:00
|
|
|
|
|
|
|
checkCUresult(cuInit(0));
|
|
|
|
checkCUresult(cuDeviceGet(&cuDevice, 0));
|
|
|
|
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
|
2022-08-22 15:01:03 +08:00
|
|
|
}
|
|
|
|
virtual ~CudaRuntimeObj() {
|
2022-08-29 21:37:07 +08:00
|
|
|
dealloc(workspace);
|
2022-08-22 15:01:03 +08:00
|
|
|
checkCudnnError(cudnnDestroy(cudnn));
|
|
|
|
checkCublasError(cublasDestroy(cublas));
|
2022-09-22 18:06:45 +08:00
|
|
|
checkCUresult(cuCtxDestroy(newContext));
|
2022-08-22 15:01:03 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void run(const Graph &graph, bool tune = false,
|
|
|
|
bool profiling = false) const;
|
|
|
|
// double runEvaluation(const Graph &graph, int nWarmups,
|
|
|
|
// int nEvaluations) const;
|
|
|
|
void sync() const;
|
|
|
|
CudaPtr alloc(size_t size) override {
|
|
|
|
void *ptr;
|
|
|
|
checkCudaError(cudaMalloc(&ptr, size));
|
|
|
|
return ptr;
|
|
|
|
}
|
|
|
|
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
|
|
|
|
cudnnHandle_t cudnnHandle() const { return cudnn; }
|
|
|
|
cublasHandle_t cublasHandle() const { return cublas; }
|
|
|
|
CudaPtr getWorkspace(size_t size) const {
|
|
|
|
IT_ASSERT(size <= workspaceSize);
|
|
|
|
return workspace;
|
|
|
|
}
|
|
|
|
|
2022-09-22 18:06:45 +08:00
|
|
|
void copyBlobFromCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const override {
|
2022-08-23 16:55:59 +08:00
|
|
|
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice));
|
|
|
|
}
|
|
|
|
|
2022-09-22 18:06:45 +08:00
|
|
|
void copyBlobToCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const override {
|
2022-08-23 16:55:59 +08:00
|
|
|
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost));
|
|
|
|
}
|
|
|
|
|
2022-09-22 18:06:45 +08:00
|
|
|
void copyBlobInsideRuntime(void *dst, const void *src,
|
2022-08-23 16:55:59 +08:00
|
|
|
size_t bytes) const override {
|
|
|
|
checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
|
|
|
|
}
|
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
private:
|
2022-08-29 21:37:07 +08:00
|
|
|
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
2022-08-22 15:01:03 +08:00
|
|
|
};
|
|
|
|
} // namespace infini
|