From 40dfcbc3d4571ce022b6aa39db581c8b88a75b8d Mon Sep 17 00:00:00 2001 From: billvsme <994171686@qq.com> Date: Thu, 30 Nov 2023 22:40:35 +0800 Subject: [PATCH] improve get_current_device --- src/llmtuner/extras/misc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 672110cf..2e8d16a8 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def get_current_device() -> str: import accelerate - dummy_accelerator = accelerate.Accelerator() + local_rank = int(os.environ.get('LOCAL_RANK', '0')) if accelerate.utils.is_xpu_available(): - return "xpu:{}".format(dummy_accelerator.local_process_index) + return "xpu:{}".format(local_rank) else: - return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu" + return local_rank if torch.cuda.is_available() else "cpu" def get_logits_processor() -> "LogitsProcessorList":