From d4908d57085bbcfcd29e0a8d4ee6425318ee4285 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 6 Jun 2024 01:28:14 +0800
Subject: [PATCH] add llamafactory-cli env

---
 .github/ISSUE_TEMPLATE/bug-report.yml |  6 +--
 src/llamafactory/cli.py               |  6 ++-
 src/llamafactory/extras/env.py        | 54 +++++++++++++++++++++++++++
 src/llamafactory/extras/packages.py   |  4 ++
 4 files changed, 65 insertions(+), 5 deletions(-)
 create mode 100644 src/llamafactory/extras/env.py

diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 82620fdb..1d962200 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -20,10 +20,10 @@ body:
     attributes:
       label: System Info
       description: |
-        Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
-        请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
+        Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
+        请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
 
-      placeholder: transformers version, platform, python version, ...
+      placeholder: llamafactory version, platform, python version, ...
 
   - type: textarea
     id: reproduction
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index c14ae6ec..fbe18d86 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -8,6 +8,7 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras.env import VERSION, print_env
 from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
@@ -29,8 +30,6 @@ USAGE = (
     + "-" * 70
 )
 
-VERSION = "0.7.2.dev0"
-
 WELCOME = (
     "-" * 58
     + "\n"
@@ -50,6 +49,7 @@ logger = get_logger(__name__)
 class Command(str, Enum):
     API = "api"
     CHAT = "chat"
+    ENV = "env"
     EVAL = "eval"
     EXPORT = "export"
     TRAIN = "train"
@@ -65,6 +65,8 @@ def main():
         run_api()
     elif command == Command.CHAT:
         run_chat()
+    elif command == Command.ENV:
+        print_env()
     elif command == Command.EVAL:
         run_eval()
     elif command == Command.EXPORT:
diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py
new file mode 100644
index 00000000..27453a6b
--- /dev/null
+++ b/src/llamafactory/extras/env.py
@@ -0,0 +1,54 @@
+import platform
+
+import accelerate
+import datasets
+import peft
+import torch
+import transformers
+import trl
+from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available
+
+from .packages import is_deepspeed_available, is_vllm_available
+
+
+VERSION = "0.7.2.dev0"
+
+
+def print_env() -> None:
+    info = {
+        "`llamafactory` version": VERSION,
+        "Platform": platform.platform(),
+        "Python version": platform.python_version(),
+        "PyTorch version": torch.__version__,
+        "Transformers version": transformers.__version__,
+        "Datasets version": datasets.__version__,
+        "Accelerate version": accelerate.__version__,
+        "PEFT version": peft.__version__,
+        "TRL version": trl.__version__,
+    }
+
+    if is_torch_cuda_available():
+        info["PyTorch version"] += " (GPU)"
+        info["GPU type"] = torch.cuda.get_device_name()
+
+    if is_torch_npu_available():
+        info["PyTorch version"] += " (NPU)"
+        info["NPU type"] = torch.npu.get_device_name()
+        info["CANN version"] = torch.version.cann
+
+    if is_deepspeed_available():
+        import deepspeed  # type: ignore
+
+        info["DeepSpeed version"] = deepspeed.__version__
+
+    if is_bitsandbytes_available():
+        import bitsandbytes
+
+        info["Bitsandbytes version"] = bitsandbytes.__version__
+
+    if is_vllm_available():
+        import vllm
+
+        info["vLLM version"] = vllm.__version__
+
+    print("\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 4c9e6492..fe056e2d 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -20,6 +20,10 @@ def _get_package_version(name: str) -> "Version":
         return version.parse("0.0.0")
 
 
+def is_deepspeed_available():
+    return _is_package_available("deepspeed")
+
+
 def is_fastapi_available():
     return _is_package_available("fastapi")
 