diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 33efb7d2..e1ae7d9f 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -73,7 +73,7 @@ def get_current_device() -> str: if accelerate.utils.is_xpu_available(): return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) elif accelerate.utils.is_npu_available() or torch.cuda.is_available(): - return os.environ.get("LOCAL_RANK", "0") + return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) else: return "cpu"