fix llamafactory-cli env

This commit is contained in:
hiyouga 2024-06-08 07:15:45 +08:00
parent 3ac11e77cc
commit 972ec9c668
1 changed files with 10 additions and 7 deletions

View File

@ -6,10 +6,7 @@ import peft
import torch import torch
import transformers import transformers
import trl import trl
from transformers.integrations import is_deepspeed_available from transformers.utils import is_torch_cuda_available, is_torch_npu_available
from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available
from .packages import is_vllm_available
VERSION = "0.8.1.dev0" VERSION = "0.8.1.dev0"
@ -37,19 +34,25 @@ def print_env() -> None:
info["NPU type"] = torch.npu.get_device_name() info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann info["CANN version"] = torch.version.cann
if is_deepspeed_available(): try:
import deepspeed # type: ignore import deepspeed # type: ignore
info["DeepSpeed version"] = deepspeed.__version__ info["DeepSpeed version"] = deepspeed.__version__
except Exception:
pass
if is_bitsandbytes_available(): try:
import bitsandbytes import bitsandbytes
info["Bitsandbytes version"] = bitsandbytes.__version__ info["Bitsandbytes version"] = bitsandbytes.__version__
except Exception:
pass
if is_vllm_available(): try:
import vllm import vllm
info["vLLM version"] = vllm.__version__ info["vLLM version"] = vllm.__version__
except Exception:
pass
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n") print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")