From adb1bce1bc3151799ce40fe5cc3ba9b1fb149d02 Mon Sep 17 00:00:00 2001 From: zweien <278954153@qq.com> Date: Tue, 19 Oct 2021 16:20:27 +0800 Subject: [PATCH] feat: set gpu device explicitly. - add `use_gpu()`, `use_gpu()` in `__init__.py` --- idrlnet/__init__.py | 43 +++++++++++++++++++++++++++++++++++-------- idrlnet/shortcut.py | 2 +- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/idrlnet/__init__.py b/idrlnet/__init__.py index 169e941..d089a5c 100644 --- a/idrlnet/__init__.py +++ b/idrlnet/__init__.py @@ -1,16 +1,43 @@ import torch +from .header import logger +GPU_AVAILABLE = False +GPU_ENABLED = False # todo more careful check -GPU_ENABLED = True if torch.cuda.is_available(): try: _ = torch.Tensor([0.0, 0.0]).cuda() - torch.set_default_tensor_type("torch.cuda.FloatTensor") - print("gpu available") - GPU_ENABLED = True + logger.info("GPU available") + GPU_AVAILABLE = True except: - print("gpu not available") - GPU_ENABLED = False + logger.info("GPU not available") + GPU_AVAILABLE = False else: - print("gpu not available") - GPU_ENABLED = False + logger.info("GPU not available") + GPU_AVAILABLE = False + + +def use_gpu(device=0): + """Use GPU with device `device`. + + Args: + device (torch.device or int): selected device. + """ + if GPU_AVAILABLE: + try: + torch.cuda.set_device(device) + torch.set_default_tensor_type("torch.cuda.FloatTensor") + logger.info(f"Using GPU device {device}") + global GPU_ENABLED + GPU_ENABLED = True + except: + logger.warning("Invalid device ordinal") + + +def use_cpu(): + """ + Use CPU. + """ + if GPU_ENABLED: + torch.set_default_tensor_type("torch.FloatTensor") + logger.info(f"Using CPU") diff --git a/idrlnet/shortcut.py b/idrlnet/shortcut.py index 88ed354..a9905ee 100644 --- a/idrlnet/shortcut.py +++ b/idrlnet/shortcut.py @@ -10,4 +10,4 @@ from idrlnet.callbacks import GradientReceiver from idrlnet.receivers import Receiver, Signal from idrlnet.variable import Variables, export_var from idrlnet.header import logger -from idrlnet import GPU_ENABLED +from idrlnet import GPU_AVAILABLE, GPU_ENABLED, use_gpu