feat: set gpu device explicitly.

- add `use_gpu()`, `use_cpu()` in `__init__.py`
zweien 2021-10-19 16:20:27 +08:00
parent 7579a8fb74
commit adb1bce1bc
2 changed files with 36 additions and 9 deletions

__init__.py

@@ -1,16 +1,43 @@
 import torch
+from .header import logger
 
+GPU_AVAILABLE = False
+GPU_ENABLED = False
 # todo more careful check
-GPU_ENABLED = True
 if torch.cuda.is_available():
     try:
         _ = torch.Tensor([0.0, 0.0]).cuda()
         torch.set_default_tensor_type("torch.cuda.FloatTensor")
-        print("gpu available")
-        GPU_ENABLED = True
+        logger.info("GPU available")
+        GPU_AVAILABLE = True
     except:
-        print("gpu not available")
-        GPU_ENABLED = False
+        logger.info("GPU not available")
+        GPU_AVAILABLE = False
 else:
-    print("gpu not available")
-    GPU_ENABLED = False
+    logger.info("GPU not available")
+    GPU_AVAILABLE = False
+
+
+def use_gpu(device=0):
+    """Use GPU with device `device`.
+    Args:
+        device (torch.device or int): selected device.
+    """
+    if GPU_AVAILABLE:
+        try:
+            torch.cuda.set_device(device)
+            torch.set_default_tensor_type("torch.cuda.FloatTensor")
+            logger.info(f"Using GPU device {device}")
+            global GPU_ENABLED
+            GPU_ENABLED = True
+        except:
+            logger.warning("Invalid device ordinal")
+
+
+def use_cpu():
+    """
+    Use CPU.
+    """
+    if GPU_ENABLED:
+        torch.set_default_tensor_type("torch.FloatTensor")
+        logger.info(f"Using CPU")
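
A minimal usage sketch of the new helpers; the consumer script below is an assumption for illustration, not part of this commit:

import torch
import idrlnet

if idrlnet.GPU_AVAILABLE:
    idrlnet.use_gpu(device=0)  # select the CUDA device explicitly; pass 1, 2, ... for other cards
    x = torch.zeros(3)
    print(x.device)            # cuda:0, since the default tensor type is now torch.cuda.FloatTensor

idrlnet.use_cpu()              # only takes effect after a successful use_gpu()
y = torch.zeros(3)
print(y.device)                # cpu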

View File

@@ -10,4 +10,4 @@ from idrlnet.callbacks import GradientReceiver
 from idrlnet.receivers import Receiver, Signal
 from idrlnet.variable import Variables, export_var
 from idrlnet.header import logger
-from idrlnet import GPU_ENABLED
+from idrlnet import GPU_AVAILABLE, GPU_ENABLED, use_gpu
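
Note that `from idrlnet import GPU_ENABLED` binds the flag's value at import time, so a later `use_gpu()`/`use_cpu()` call is not visible through that imported name; reading the attribute off the module gives the current state. A hypothetical consumer, again not part of this diff:

import idrlnet

idrlnet.use_gpu(0)
print(idrlnet.GPU_ENABLED)  # True after a successful use_gpu(); a name bound earlier via `from idrlnet import GPU_ENABLED` would still hold the old value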