diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index c1395552..9d1cbd0b 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -37,7 +37,7 @@ from .logging import get_logger _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() try: - _is_bf16_available = is_torch_bf16_gpu_available() + _is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported()) except Exception: _is_bf16_available = False