Fix: cuCtxCreate before other initialization (#49)

Fix: create cuCtx at the very beginning

Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
zhengly123 2022-10-19 15:41:48 +08:00 committed by GitHub
parent 00b2f18c17
commit 63d8aff985
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 6 deletions

View File

@ -17,16 +17,19 @@ class CudaRuntimeObj : public RuntimeObj {
public:
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
// Prepare for nvrtc. cuCtxCreate should be called befero others.
// Otherwise it will result in strange failure, such as cuBLAS failed on
// certian inputs.
checkCUresult(cuInit(0));
checkCUresult(cuDeviceGet(&cuDevice, 0));
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
checkCudnnError(cudnnCreate(&cudnn));
checkCublasError(cublasCreate(&cublas));
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
workspace = alloc(workspaceSize);
checkCUresult(cuInit(0));
checkCUresult(cuDeviceGet(&cuDevice, 0));
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
}
virtual ~CudaRuntimeObj() {
dealloc(workspace);
@ -49,6 +52,7 @@ class CudaRuntimeObj : public RuntimeObj {
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
cudnnHandle_t cudnnHandle() const { return cudnn; }
cublasHandle_t cublasHandle() const { return cublas; }
size_t getWorkspaceSize() const { return workspaceSize; }
CudaPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
return workspace;

View File

@ -240,7 +240,8 @@ class convCudnn : public Kernel {
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS)
continue;
if (record.workspaceSize > context->getWorkspaceSize())
continue;
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;

View File

@ -17,7 +17,7 @@ string MatmulObj::toString() const {
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
<< ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
<< ")";
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
return os.str();
}