forked from idrl/idrlnet
feat: set gpu device explicitly.
- add `use_gpu()`, `use_gpu()` in `__init__.py`
This commit is contained in:
parent
7579a8fb74
commit
adb1bce1bc
|
@ -1,16 +1,43 @@
|
|||
import torch
|
||||
from .header import logger
|
||||
|
||||
GPU_AVAILABLE = False
|
||||
GPU_ENABLED = False
|
||||
# todo more careful check
|
||||
GPU_ENABLED = True
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
_ = torch.Tensor([0.0, 0.0]).cuda()
|
||||
logger.info("GPU available")
|
||||
GPU_AVAILABLE = True
|
||||
except:
|
||||
logger.info("GPU not available")
|
||||
GPU_AVAILABLE = False
|
||||
else:
|
||||
logger.info("GPU not available")
|
||||
GPU_AVAILABLE = False
|
||||
|
||||
|
||||
def use_gpu(device=0):
|
||||
"""Use GPU with device `device`.
|
||||
|
||||
Args:
|
||||
device (torch.device or int): selected device.
|
||||
"""
|
||||
if GPU_AVAILABLE:
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
||||
print("gpu available")
|
||||
logger.info(f"Using GPU device {device}")
|
||||
global GPU_ENABLED
|
||||
GPU_ENABLED = True
|
||||
except:
|
||||
print("gpu not available")
|
||||
GPU_ENABLED = False
|
||||
else:
|
||||
print("gpu not available")
|
||||
GPU_ENABLED = False
|
||||
logger.warning("Invalid device ordinal")
|
||||
|
||||
|
||||
def use_cpu():
|
||||
"""
|
||||
Use CPU.
|
||||
"""
|
||||
if GPU_ENABLED:
|
||||
torch.set_default_tensor_type("torch.FloatTensor")
|
||||
logger.info(f"Using CPU")
|
||||
|
|
|
@ -10,4 +10,4 @@ from idrlnet.callbacks import GradientReceiver
|
|||
from idrlnet.receivers import Receiver, Signal
|
||||
from idrlnet.variable import Variables, export_var
|
||||
from idrlnet.header import logger
|
||||
from idrlnet import GPU_ENABLED
|
||||
from idrlnet import GPU_AVAILABLE, GPU_ENABLED, use_gpu
|
||||
|
|
Loading…
Reference in New Issue