forked from jiuyuan/InfiniTensor
Fix: cuCtxCreate before other initialization (#49)
Fix: create cuCtx at the very beginning Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
00b2f18c17
commit
63d8aff985
|
@ -17,16 +17,19 @@ class CudaRuntimeObj : public RuntimeObj {
|
|||
|
||||
public:
|
||||
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
|
||||
// Prepare for nvrtc. cuCtxCreate should be called before others.
|
||||
// Otherwise it will result in strange failure, such as cuBLAS failed on
|
||||
// certain inputs.
|
||||
checkCUresult(cuInit(0));
|
||||
checkCUresult(cuDeviceGet(&cuDevice, 0));
|
||||
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
|
||||
|
||||
checkCudnnError(cudnnCreate(&cudnn));
|
||||
checkCublasError(cublasCreate(&cublas));
|
||||
// 10GB for Longformer
|
||||
// size_t longformerNum = 3lu * (1 << 30);
|
||||
workspaceSize = 7ll << 30; // 7 GB
|
||||
workspace = alloc(workspaceSize);
|
||||
|
||||
checkCUresult(cuInit(0));
|
||||
checkCUresult(cuDeviceGet(&cuDevice, 0));
|
||||
checkCUresult(cuCtxCreate(&newContext, 0, cuDevice));
|
||||
}
|
||||
virtual ~CudaRuntimeObj() {
|
||||
dealloc(workspace);
|
||||
|
@ -49,6 +52,7 @@ class CudaRuntimeObj : public RuntimeObj {
|
|||
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
|
||||
cudnnHandle_t cudnnHandle() const { return cudnn; }
|
||||
cublasHandle_t cublasHandle() const { return cublas; }
|
||||
size_t getWorkspaceSize() const { return workspaceSize; }
|
||||
CudaPtr getWorkspace(size_t size) const {
|
||||
IT_ASSERT(size <= workspaceSize);
|
||||
return workspace;
|
||||
|
|
|
@ -240,7 +240,8 @@ class convCudnn : public Kernel {
|
|||
ALGOS[record.algo], &record.workspaceSize);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
continue;
|
||||
|
||||
if (record.workspaceSize > context->getWorkspaceSize())
|
||||
continue;
|
||||
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
|
||||
float alpha = 1.f, beta = 0.f;
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ string MatmulObj::toString() const {
|
|||
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
|
||||
<< ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
|
||||
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
|
||||
<< ")";
|
||||
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue