forked from idrl/idrlnet
feat: set gpu device explicitly.
- add `use_gpu()`, `use_gpu()` in `__init__.py`
This commit is contained in:
parent
7579a8fb74
commit
adb1bce1bc
|
@ -1,16 +1,43 @@
|
||||||
import torch
|
import torch
|
||||||
|
from .header import logger
|
||||||
|
|
||||||
|
GPU_AVAILABLE = False
|
||||||
|
GPU_ENABLED = False
|
||||||
# todo more careful check
|
# todo more careful check
|
||||||
GPU_ENABLED = True
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
_ = torch.Tensor([0.0, 0.0]).cuda()
|
_ = torch.Tensor([0.0, 0.0]).cuda()
|
||||||
|
logger.info("GPU available")
|
||||||
|
GPU_AVAILABLE = True
|
||||||
|
except:
|
||||||
|
logger.info("GPU not available")
|
||||||
|
GPU_AVAILABLE = False
|
||||||
|
else:
|
||||||
|
logger.info("GPU not available")
|
||||||
|
GPU_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def use_gpu(device=0):
|
||||||
|
"""Use GPU with device `device`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device or int): selected device.
|
||||||
|
"""
|
||||||
|
if GPU_AVAILABLE:
|
||||||
|
try:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
||||||
print("gpu available")
|
logger.info(f"Using GPU device {device}")
|
||||||
|
global GPU_ENABLED
|
||||||
GPU_ENABLED = True
|
GPU_ENABLED = True
|
||||||
except:
|
except:
|
||||||
print("gpu not available")
|
logger.warning("Invalid device ordinal")
|
||||||
GPU_ENABLED = False
|
|
||||||
else:
|
|
||||||
print("gpu not available")
|
def use_cpu():
|
||||||
GPU_ENABLED = False
|
"""
|
||||||
|
Use CPU.
|
||||||
|
"""
|
||||||
|
if GPU_ENABLED:
|
||||||
|
torch.set_default_tensor_type("torch.FloatTensor")
|
||||||
|
logger.info(f"Using CPU")
|
||||||
|
|
|
@ -10,4 +10,4 @@ from idrlnet.callbacks import GradientReceiver
|
||||||
from idrlnet.receivers import Receiver, Signal
|
from idrlnet.receivers import Receiver, Signal
|
||||||
from idrlnet.variable import Variables, export_var
|
from idrlnet.variable import Variables, export_var
|
||||||
from idrlnet.header import logger
|
from idrlnet.header import logger
|
||||||
from idrlnet import GPU_ENABLED
|
from idrlnet import GPU_AVAILABLE, GPU_ENABLED, use_gpu
|
||||||
|
|
Loading…
Reference in New Issue