commit
d043a4e7ba
|
@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
|
||||
def get_current_device() -> str:
|
||||
import accelerate
|
||||
dummy_accelerator = accelerate.Accelerator()
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
|
||||
if accelerate.utils.is_xpu_available():
|
||||
return "xpu:{}".format(dummy_accelerator.local_process_index)
|
||||
return "xpu:{}".format(local_rank)
|
||||
else:
|
||||
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
|
||||
return local_rank if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
|
|
Loading…
Reference in New Issue