diff --git a/src/llmtuner/cli.py b/src/llmtuner/cli.py index 1b5bd658..f2619ab9 100644 --- a/src/llmtuner/cli.py +++ b/src/llmtuner/cli.py @@ -1,6 +1,7 @@ import sys from enum import Enum, unique +from . import __version__ from .api.app import run_api from .chat.chat_model import run_chat from .eval.evaluator import run_eval @@ -8,6 +9,19 @@ from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui +USAGE = """ +Usage: + llamafactory-cli api -h: launch an API server + llamafactory-cli chat -h: launch a chat interface in CLI + llamafactory-cli eval -h: do evaluation + llamafactory-cli export -h: merge LoRA adapters and export model + llamafactory-cli train -h: do training + llamafactory-cli webchat -h: launch a chat interface in Web UI + llamafactory-cli webui: launch LlamaBoard + llamafactory-cli version: show version info +""" + + @unique class Command(str, Enum): API = "api" @@ -17,6 +31,8 @@ class Command(str, Enum): TRAIN = "train" WEBDEMO = "webchat" WEBUI = "webui" + VERSION = "version" + HELP = "help" def main(): @@ -35,5 +51,9 @@ def main(): run_web_demo() elif command == Command.WEBUI: run_web_ui() + elif command == Command.VERSION: + print("Welcome to LLaMA Factory, version {}".format(__version__)) + elif command == Command.HELP: + print(USAGE) else: raise NotImplementedError("Unknown command: {}".format(command))