diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index 96e25f9b..924db59e 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -17,16 +17,19 @@ class CudaRuntimeObj : public RuntimeObj { public: CudaRuntimeObj() : RuntimeObj(Device::CUDA) { + // Prepare for nvrtc. cuCtxCreate should be called befero others. + // Otherwise it will result in strange failure, such as cuBLAS failed on + // certian inputs. + checkCUresult(cuInit(0)); + checkCUresult(cuDeviceGet(&cuDevice, 0)); + checkCUresult(cuCtxCreate(&newContext, 0, cuDevice)); + checkCudnnError(cudnnCreate(&cudnn)); checkCublasError(cublasCreate(&cublas)); // 10GB for Longformer // size_t longformerNum = 3lu * (1 << 30); workspaceSize = 7ll << 30; // 7 GB workspace = alloc(workspaceSize); - - checkCUresult(cuInit(0)); - checkCUresult(cuDeviceGet(&cuDevice, 0)); - checkCUresult(cuCtxCreate(&newContext, 0, cuDevice)); } virtual ~CudaRuntimeObj() { dealloc(workspace); @@ -49,6 +52,7 @@ class CudaRuntimeObj : public RuntimeObj { void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); } cudnnHandle_t cudnnHandle() const { return cudnn; } cublasHandle_t cublasHandle() const { return cublas; } + size_t getWorkspaceSize() const { return workspaceSize; } CudaPtr getWorkspace(size_t size) const { IT_ASSERT(size <= workspaceSize); return workspace; diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc index f6afa20d..c020ed33 100644 --- a/src/kernels/cuda/conv.cc +++ b/src/kernels/cuda/conv.cc @@ -240,7 +240,8 @@ class convCudnn : public Kernel { ALGOS[record.algo], &record.workspaceSize); if (stat != CUDNN_STATUS_SUCCESS) continue; - + if (record.workspaceSize > context->getWorkspaceSize()) + continue; CudaPtr wsData = context->getWorkspace(record.workspaceSize); float alpha = 1.f, beta = 0.f; diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 4d5debdb..8e82ccc7 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -17,7 +17,7 @@ string MatmulObj::toString() const { os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B") << ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid() << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid() - << ")"; + << ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"; return os.str(); }