Merge pull request #1690 from billvsme/main

Improve get_current_device
This commit is contained in:
hoshi-hiyouga 2023-12-01 15:44:35 +08:00 committed by GitHub
commit d043a4e7ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def get_current_device() -> str:
import accelerate
dummy_accelerator = accelerate.Accelerator()
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
return "xpu:{}".format(local_rank)
else:
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
return local_rank if torch.cuda.is_available() else "cpu"
def get_logits_processor() -> "LogitsProcessorList":