diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 672110cf..2e8d16a8 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def get_current_device() -> str: import accelerate - dummy_accelerator = accelerate.Accelerator() + local_rank = int(os.environ.get('LOCAL_RANK', '0')) if accelerate.utils.is_xpu_available(): - return "xpu:{}".format(dummy_accelerator.local_process_index) + return "xpu:{}".format(local_rank) else: - return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu" + return local_rank if torch.cuda.is_available() else "cpu" def get_logits_processor() -> "LogitsProcessorList":