diff --git a/idrlnet/__init__.py b/idrlnet/__init__.py
index 169e941..d089a5c 100644
--- a/idrlnet/__init__.py
+++ b/idrlnet/__init__.py
@@ -1,16 +1,43 @@
 import torch
+from .header import logger
 
+GPU_AVAILABLE = False
+GPU_ENABLED = False
 # todo more careful check
-GPU_ENABLED = True
 if torch.cuda.is_available():
     try:
         _ = torch.Tensor([0.0, 0.0]).cuda()
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
-        print("gpu available")
-        GPU_ENABLED = True
+        logger.info("GPU available")
+        GPU_AVAILABLE = True
     except:
-        print("gpu not available")
-        GPU_ENABLED = False
+        logger.info("GPU not available")
+        GPU_AVAILABLE = False
 else:
-    print("gpu not available")
-    GPU_ENABLED = False
+    logger.info("GPU not available")
+    GPU_AVAILABLE = False
+
+
+def use_gpu(device=0):
+    """Use GPU with device `device`.
+
+    Args:
+        device (torch.device or int): selected device.
+    """
+    if GPU_AVAILABLE:
+        try:
+            torch.cuda.set_device(device)
+            torch.set_default_tensor_type("torch.cuda.FloatTensor")
+            logger.info(f"Using GPU device {device}")
+            global GPU_ENABLED
+            GPU_ENABLED = True
+        except:
+            logger.warning("Invalid device ordinal")
+
+
+def use_cpu():
+    """
+    Use CPU.
+    """
+    if GPU_ENABLED:
+        torch.set_default_tensor_type("torch.FloatTensor")
+        logger.info(f"Using CPU")
diff --git a/idrlnet/shortcut.py b/idrlnet/shortcut.py
index 88ed354..a9905ee 100644
--- a/idrlnet/shortcut.py
+++ b/idrlnet/shortcut.py
@@ -10,4 +10,4 @@ from idrlnet.callbacks import GradientReceiver
 from idrlnet.receivers import Receiver, Signal
 from idrlnet.variable import Variables, export_var
 from idrlnet.header import logger
-from idrlnet import GPU_ENABLED
+from idrlnet import GPU_AVAILABLE, GPU_ENABLED, use_gpu
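
For context, a minimal usage sketch of the device-selection helpers this change introduces. This is an illustrative assumption, not part of the diff; it presumes `idrlnet` is importable and uses only the names re-exported through `idrlnet/shortcut.py` above.

# Hypothetical usage sketch: select the compute device once, before building a model.
from idrlnet.shortcut import GPU_AVAILABLE, use_gpu

if GPU_AVAILABLE:
    # Selects cuda device 0 and switches the default tensor type to
    # torch.cuda.FloatTensor; on success GPU_ENABLED is flipped to True.
    use_gpu(device=0)
# If no GPU is usable, the package keeps the CPU default (torch.FloatTensor).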