update webui and add CLIs

This commit is contained in:
hiyouga 2024-05-03 02:58:23 +08:00
parent 39e964a97a
commit 245fe47ece
65 changed files with 363 additions and 372 deletions

View File

@ -11,4 +11,4 @@ RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ] VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
EXPOSE 7860 EXPOSE 7860
CMD [ "python", "src/train_web.py" ] CMD [ "llamafactory-cli webui" ]

View File

@ -346,7 +346,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
```bash ```bash
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
python src/train_web.py # or python -m llmtuner.webui.interface llamafactory-cli webui
``` ```
<details><summary>For Alibaba Cloud users</summary> <details><summary>For Alibaba Cloud users</summary>
@ -392,12 +392,12 @@ docker compose -f ./docker-compose.yml up -d
See [examples/README.md](examples/README.md) for usage. See [examples/README.md](examples/README.md) for usage.
Use `python src/train_bash.py -h` to display arguments description. Use `llamafactory-cli train -h` to display arguments description.
### Deploy with OpenAI-style API and vLLM ### Deploy with OpenAI-style API and vLLM
```bash ```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \ --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \ --template llama3 \
--infer_backend vllm \ --infer_backend vllm \

View File

@ -346,7 +346,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
```bash ```bash
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0` export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860` export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
python src/train_web.py # 或 python -m llmtuner.webui.interface llamafactory-cli webui
``` ```
<details><summary>阿里云用户指南</summary> <details><summary>阿里云用户指南</summary>
@ -392,12 +392,12 @@ docker compose -f ./docker-compose.yml up -d
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。 使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
您可以执行 `python src/train_bash.py -h` 来查看参数文档。 您可以执行 `llamafactory-cli train -h` 来查看参数文档。
### 利用 vLLM 部署 OpenAI API ### 利用 vLLM 部署 OpenAI API
```bash ```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \ --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \ --template llama3 \
--infer_backend vllm \ --infer_backend vllm \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -7,7 +7,7 @@ pip install "bitsandbytes>=0.43.0"
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \ CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file ../../accelerate/fsdp_config.yaml \ --config_file ../../accelerate/fsdp_config.yaml \
../../../src/train_bash.py \ ../../../src/train.py \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-70b-hf \ --model_name_or_path meta-llama/Llama-2-70b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path ../../../models/llama2-7b-pro \ --model_name_or_path ../../../models/llama2-7b-pro \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -6,7 +6,7 @@ python -m torch.distributed.run \
--node_rank $RANK \ --node_rank $RANK \
--master_addr $MASTER_ADDR \ --master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \ --master_port $MASTER_PORT \
../../src/train_bash.py \ ../../src/train.py \
--deepspeed ../deepspeed/ds_z3_config.json \ --deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \ --stage sft \
--do_train \ --do_train \

View File

@ -2,7 +2,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \ --config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \ ../../src/train.py \
--stage sft \ --stage sft \
--do_predict \ --do_predict \
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \ --model_name_or_path ../../saves/LLaMA2-7B/full/sft \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
deepspeed --num_gpus 4 ../../src/train_bash.py \ deepspeed --num_gpus 4 ../../src/train.py \
--deepspeed ../deepspeed/ds_z3_config.json \ --deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \ --stage sft \
--do_train \ --do_train \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \ CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \ --template default \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \ --template default \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template fewshot \ --template fewshot \

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
# add `--visual_inputs True` to load MLLM # add `--visual_inputs True` to load MLLM
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \ --template default \

View File

@ -1,6 +1,7 @@
#!/bin/bash #!/bin/bash
# ZeRO-3 enables weight sharding on multiple GPUs
deepspeed --num_gpus 4 ../../src/train_bash.py \ deepspeed --num_gpus 4 ../../src/train.py \
--deepspeed ../deepspeed/ds_z3_config.json \ --deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \ --stage sft \
--do_train \ --do_train \

View File

@ -3,7 +3,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/master_config.yaml \ --config_file ../accelerate/master_config.yaml \
../../src/train_bash.py \ ../../src/train.py \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -2,7 +2,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \ --config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \ ../../src/train.py \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage dpo \ --stage dpo \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage orpo \ --stage orpo \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage ppo \ --stage ppo \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_predict \ --do_predict \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
# use `--tokenized_path` in training script to load data # use `--tokenized_path` in training script to load data
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES= llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage pt \ --stage pt \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage rm \ --stage rm \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \ --model_name_or_path llava-hf/llava-1.5-7b-hf \

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
# DO NOT use quantized model or quantization_bit when merging lora weights # DO NOT use quantized model or quantization_bit when merging lora weights
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \ --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \ --template default \

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
# NEED TO run `merge.sh` before using this script # NEED TO run `merge.sh` before using this script
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
--model_name_or_path ../../models/llama2-7b-sft \ --model_name_or_path ../../models/llama2-7b-sft \
--template default \ --template default \
--export_dir ../../models/llama2-7b-sft-int4 \ --export_dir ../../models/llama2-7b-sft-int4 \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \ --model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path TheBloke/Llama-2-7B-AWQ \ --model_name_or_path TheBloke/Llama-2-7B-AWQ \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \ --model_name_or_path meta-llama/Llama-2-7b-hf \

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \ --stage sft \
--do_train \ --do_train \
--model_name_or_path TheBloke/Llama-2-7B-GPTQ \ --model_name_or_path TheBloke/Llama-2-7B-GPTQ \

View File

@ -16,3 +16,4 @@ sse-starlette
matplotlib matplotlib
fire fire
packaging packaging
pyyaml

View File

@ -52,6 +52,7 @@ def main():
python_requires=">=3.8.0", python_requires=">=3.8.0",
install_requires=get_requires(), install_requires=get_requires(),
extras_require=extra_require, extras_require=extra_require,
entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@ -1,16 +0,0 @@
import os
import uvicorn
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__":
main()

View File

@ -1,49 +0,0 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
def main():
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()

View File

@ -1,9 +0,0 @@
from llmtuner import Evaluator
def main():
Evaluator().eval()
if __name__ == "__main__":
main()

View File

@ -1,9 +0,0 @@
from llmtuner import export_model
def main():
export_model()
if __name__ == "__main__":
main()

View File

@ -1,11 +1,3 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams # Level: api, webui > chat, eval, train > data, model > extras, hparams
from .api import create_app __version__ = "0.7.1.dev0"
from .chat import ChatModel
from .eval import Evaluator
from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.7.0"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@ -1,4 +0,0 @@
from .app import create_app
__all__ = ["create_app"]

View File

@ -224,7 +224,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
return app return app
if __name__ == "__main__": def run_api():
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1) uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)

View File

@ -2,6 +2,7 @@ import asyncio
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.misc import torch_gc
from ..hparams import get_infer_args from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine from .vllm_engine import VllmEngine
@ -95,3 +96,45 @@ class ChatModel:
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> List[float]:
return await self.engine.get_scores(batch_input, **input_kwargs) return await self.engine.get_scores(batch_input, **input_kwargs)
def run_chat():
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})

39
src/llmtuner/cli.py Normal file
View File

@ -0,0 +1,39 @@
import sys
from enum import Enum, unique
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
def main():
command = sys.argv.pop(1)
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
else:
raise NotImplementedError("Unknown command: {}".format(command))

View File

@ -1,4 +0,0 @@
from .evaluator import Evaluator
__all__ = ["Evaluator"]

View File

@ -118,6 +118,6 @@ class Evaluator:
f.write(score_info) f.write(score_info)
if __name__ == "__main__": def run_eval():
evaluator = Evaluator() evaluator = Evaluator()
evaluator.eval() evaluator.eval()

View File

@ -1,14 +1,18 @@
import json import json
import logging
import os import os
import signal
import time import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict
import transformers
from transformers import TrainerCallback from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME from .constants import TRAINER_LOG
from .logging import get_logger from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint from .misc import fix_valuehead_checkpoint
@ -33,20 +37,32 @@ class FixValueHeadModelCallback(TrainerCallback):
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
def __init__(self, runner=None): def __init__(self, output_dir: str) -> None:
self.runner = runner self.aborted = False
self.in_training = False self.do_train = False
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time() self.start_time = time.time()
self.cur_steps = 0 self.cur_steps = 0
self.max_steps = 0 self.max_steps = max_steps
self.elapsed_time = "" self.elapsed_time = ""
self.remaining_time = "" self.remaining_time = ""
def timing(self): def _timing(self, cur_steps: int) -> None:
cur_time = time.time() cur_time = time.time()
elapsed_time = cur_time - self.start_time elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time))) self.remaining_time = str(timedelta(seconds=int(remaining_time)))
@ -54,36 +70,27 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called at the beginning of training. Event called at the beginning of training.
""" """
if state.is_local_process_zero: if args.should_log:
self.in_training = True self.do_train = True
self.start_time = time.time() self._reset(max_steps=state.max_steps)
self.max_steps = state.max_steps
if args.save_on_each_node: if args.should_save:
if not state.is_local_process_zero: os.makedirs(args.output_dir, exist_ok=True)
return self.thread_pool = ThreadPoolExecutor(max_workers=1)
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: if (
logger.warning("Previous log file in this folder will be deleted.") args.should_save
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): ):
r""" logger.warning("Previous trainer log in this folder will be deleted.")
Event called at the end of training. os.remove(os.path.join(args.output_dir, TRAINER_LOG))
"""
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of an substep during gradient accumulation. Event called at the end of an substep during gradient accumulation.
""" """
if state.is_local_process_zero and self.runner is not None and self.runner.aborted: if self.aborted:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
@ -91,42 +98,41 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called at the end of a training step. Event called at the end of a training step.
""" """
if state.is_local_process_zero: if args.should_log:
self.cur_steps = state.global_step self._timing(cur_steps=state.global_step)
self.timing()
if self.runner is not None and self.runner.aborted: if self.aborted:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called after an evaluation phase. Event called at the end of training.
""" """
if state.is_local_process_zero and not self.in_training: self.thread_pool.shutdown(wait=True)
self.cur_steps = 0 self.thread_pool = None
self.max_steps = 0
def on_predict( def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
): ):
r""" r"""
Event called after a successful prediction. Event called after a prediction step.
""" """
if state.is_local_process_zero and not self.in_training: eval_dataloader = kwargs.pop("eval_dataloader", None)
self.cur_steps = 0 if args.should_log and has_length(eval_dataloader) and not self.do_train:
self.max_steps = 0 if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self._timing(cur_steps=self.cur_steps + 1)
def _write_log(self, output_dir: str, logs: Dict[str, Any]):
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r""" r"""
Event called after logging the last logs. Event called after logging the last logs, `args.should_log` has been applied.
""" """
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,
total_steps=self.max_steps, total_steps=self.max_steps,
@ -141,26 +147,13 @@ class LogCallback(TrainerCallback):
elapsed_time=self.elapsed_time, elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time, remaining_time=self.remaining_time,
) )
if self.runner is not None: logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs:
logger.info( logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 logs["loss"], logs["learning_rate"], logs["epoch"]
) )
) )
os.makedirs(args.output_dir, exist_ok=True) if args.should_save and self.thread_pool is not None:
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: self.thread_pool.submit(self._write_log, args.output_dir, logs)
f.write(json.dumps(logs) + "\n")
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()

View File

@ -24,8 +24,6 @@ IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"} LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
MLLM_LIST = ["LLaVA1.5"] MLLM_LIST = ["LLaVA1.5"]
@ -34,10 +32,16 @@ MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral
PEFT_METHODS = ["lora"] PEFT_METHODS = ["lora"]
RUNNING_LOG = "running_log.txt"
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict() SUPPORTED_MODELS = OrderedDict()
TRAINER_CONFIG = "trainer_config.yaml"
TRAINER_LOG = "trainer_log.jsonl"
TRAINING_STAGES = { TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft", "Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm", "Reward Modeling": "rm",

View File

@ -1,5 +1,9 @@
import logging import logging
import os
import sys import sys
from concurrent.futures import ThreadPoolExecutor
from .constants import RUNNING_LOG
class LoggerHandler(logging.Handler): class LoggerHandler(logging.Handler):
@ -7,19 +11,35 @@ class LoggerHandler(logging.Handler):
Logger handler used in Web UI. Logger handler used in Web UI.
""" """
def __init__(self): def __init__(self, output_dir: str) -> None:
super().__init__() super().__init__()
self.log = "" formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
self.setLevel(logging.INFO)
self.setFormatter(formatter)
def reset(self): os.makedirs(output_dir, exist_ok=True)
self.log = "" self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log):
os.remove(self.running_log)
def emit(self, record): self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _write_log(self, log_entry: str) -> None:
with open(self.running_log, "a", encoding="utf-8") as f:
f.write(log_entry + "\n\n")
def emit(self, record) -> None:
if record.name == "httpx": if record.name == "httpx":
return return
log_entry = self.format(record) log_entry = self.format(record)
self.log += log_entry self.thread_pool.submit(self._write_log, log_entry)
self.log += "\n\n"
def close(self) -> None:
self.thread_pool.shutdown(wait=True)
return super().close()
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:

View File

@ -1,7 +1,7 @@
import json import json
import math import math
import os import os
from typing import List from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -21,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
EMA implementation according to TensorBoard. EMA implementation according to TensorBoard.
""" """
last = scalars[0] last = scalars[0]
smoothed = list() smoothed = []
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars: for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val smoothed_val = last * weight + (1 - weight) * next_val
@ -30,7 +31,27 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
for log in trainer_log:
if log.get("loss", None):
steps.append(log["current_steps"])
losses.append(log["loss"])
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)

View File

@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras.constants import TRAINER_CONFIG
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments from .data_args import DataArguments
@ -251,7 +252,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and can_resume_from_checkpoint and can_resume_from_checkpoint
): ):
last_checkpoint = get_last_checkpoint(training_args.output_dir) last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: files = os.listdir(training_args.output_dir)
if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None: if last_checkpoint is not None:

View File

@ -1,4 +0,0 @@
from .tuner import export_model, run_exp
__all__ = ["export_model", "run_exp"]

View File

@ -23,9 +23,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []):
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks callbacks.append(LogCallback(training_args.output_dir))
if finetuning_args.stage == "pt": if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks) run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
@ -88,7 +88,3 @@ def export_model(args: Optional[Dict[str, Any]] = None):
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except Exception: except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.") logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
run_exp()

View File

@ -1,4 +0,0 @@
from .interface import create_ui, create_web_demo
__all__ = ["create_ui", "create_web_demo"]

View File

@ -4,6 +4,7 @@ from collections import defaultdict
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from yaml import safe_dump, safe_load
from ..extras.constants import ( from ..extras.constants import (
DATA_CONFIG, DATA_CONFIG,
@ -29,7 +30,7 @@ DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config" DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data" DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves" DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config" USER_CONFIG = "user_config.yaml"
def get_save_dir(*args) -> os.PathLike: def get_save_dir(*args) -> os.PathLike:
@ -47,7 +48,7 @@ def get_save_path(config_path: str) -> os.PathLike:
def load_config() -> Dict[str, Any]: def load_config() -> Dict[str, Any]:
try: try:
with open(get_config_path(), "r", encoding="utf-8") as f: with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f) return safe_load(f)
except Exception: except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -60,13 +61,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
user_config["last_model"] = model_name user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f: with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False) safe_dump(user_config, f)
def load_args(config_path: str) -> Optional[Dict[str, Any]]: def load_args(config_path: str) -> Optional[Dict[str, Any]]:
try: try:
with open(get_save_path(config_path), "r", encoding="utf-8") as f: with open(get_save_path(config_path), "r", encoding="utf-8") as f:
return json.load(f) return safe_load(f)
except Exception: except Exception:
return None return None
@ -74,7 +75,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str: def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True) os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
with open(get_save_path(config_path), "w", encoding="utf-8") as f: with open(get_save_path(config_path), "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=2, ensure_ascii=False) safe_dump(config_dict, f)
return str(get_save_path(config_path)) return str(get_save_path(config_path))

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict, Generator, List
from ...extras.misc import torch_gc from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available from ...extras.packages import is_gradio_available
from ...train import export_model from ...train.tuner import export_model
from ..common import get_save_dir from ..common import get_save_dir
from ..locales import ALERTS from ..locales import ALERTS

View File

@ -245,7 +245,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False) resume_btn = gr.Checkbox(visible=False, interactive=False)
process_bar = gr.Slider(visible=False, interactive=False) progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row(): with gr.Row():
output_box = gr.Markdown() output_box = gr.Markdown()
@ -263,14 +263,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir=output_dir, output_dir=output_dir,
config_path=config_path, config_path=config_path,
resume_btn=resume_btn, resume_btn=resume_btn,
process_bar=process_bar, progress_bar=progress_bar,
output_box=output_box, output_box=output_box,
loss_viewer=loss_viewer, loss_viewer=loss_viewer,
) )
) )
input_elems.update({output_dir, config_path}) input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar, loss_viewer] output_elems = [output_box, progress_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)

View File

@ -41,7 +41,7 @@ class Engine:
init_dict["train.dataset"] = {"choices": list_dataset().choices} init_dict["train.dataset"] = {"choices": list_dataset().choices}
init_dict["eval.dataset"] = {"choices": list_dataset().choices} init_dict["eval.dataset"] = {"choices": list_dataset().choices}
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())} init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())} init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())} init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False} init_dict["infer.image_box"] = {"visible": False}
@ -51,7 +51,7 @@ class Engine:
yield self._update_component(init_dict) yield self._update_component(init_dict)
if self.runner.alive and not self.demo_mode and not self.pure_chat: if self.runner.running and not self.demo_mode and not self.pure_chat:
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()} yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
if self.runner.do_train: if self.runner.do_train:
yield self._update_component({"train.resume_btn": {"value": True}}) yield self._update_component({"train.resume_btn": {"value": True}})

View File

@ -68,5 +68,9 @@ def create_web_demo() -> gr.Blocks:
return demo return demo
if __name__ == "__main__": def run_web_ui():
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True) create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
def run_web_demo():
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)

View File

@ -1,22 +1,19 @@
import logging
import os import os
import time import signal
from threading import Thread from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import transformers import psutil
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available from transformers.utils import is_torch_cuda_available
from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_gradio_available from ..extras.packages import is_gradio_available
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS from .locales import ALERTS
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
if is_gradio_available(): if is_gradio_available():
@ -34,24 +31,18 @@ class Runner:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode self.demo_mode = demo_mode
""" Resume """ """ Resume """
self.thread: "Thread" = None self.trainer: Optional["Popen"] = None
self.do_train = True self.do_train = True
self.running_data: Dict["Component", Any] = None self.running_data: Dict["Component", Any] = None
""" State """ """ State """
self.aborted = False self.aborted = False
self.running = False self.running = False
""" Handler """
self.logger_handler = LoggerHandler()
self.logger_handler.setLevel(logging.INFO)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
@property
def alive(self) -> bool:
return self.thread is not None
def set_abort(self) -> None: def set_abort(self) -> None:
self.aborted = True self.aborted = True
if self.trainer is not None:
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
os.kill(children.pid, signal.SIGABRT)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@ -85,13 +76,11 @@ class Runner:
if not from_preview and not is_torch_cuda_available(): if not from_preview and not is_torch_cuda_available():
gr.Warning(ALERTS["warn_no_cuda"][lang]) gr.Warning(ALERTS["warn_no_cuda"][lang])
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
return "" return ""
def _finalize(self, lang: str, finish_info: str) -> str: def _finalize(self, lang: str, finish_info: str) -> str:
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
self.thread = None self.trainer = None
self.aborted = False self.aborted = False
self.running = False self.running = False
self.running_data = None self.running_data = None
@ -270,11 +259,12 @@ class Runner:
gr.Warning(error) gr.Warning(error)
yield {output_box: error} yield {output_box: error}
else: else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.do_train, self.running_data = do_train, data self.do_train, self.running_data = do_train, data
self.thread = Thread(target=run_exp, kwargs=run_kwargs) args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
self.thread.start() env = deepcopy(os.environ)
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
env["LLAMABOARD_ENABLED"] = "1"
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor() yield from self.monitor()
def preview_train(self, data): def preview_train(self, data):
@ -291,9 +281,6 @@ class Runner:
def monitor(self): def monitor(self):
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.aborted = False
self.running = True
lang = get("top.lang") lang = get("top.lang")
model_name = get("top.model_name") model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type") finetuning_type = get("top.finetuning_type")
@ -301,28 +288,31 @@ class Runner:
output_path = get_save_dir(model_name, finetuning_type, output_dir) output_path = get_save_dir(model_name, finetuning_type, output_dir)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval")) output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval")) progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.thread is not None and self.thread.is_alive(): while self.trainer is not None:
if self.aborted: if self.aborted:
yield { yield {
output_box: ALERTS["info_aborting"][lang], output_box: ALERTS["info_aborting"][lang],
process_bar: gr.Slider(visible=False), progress_bar: gr.Slider(visible=False),
} }
else: else:
running_log, running_progress, running_loss = get_trainer_info(output_path)
return_dict = { return_dict = {
output_box: self.logger_handler.log, output_box: running_log,
process_bar: update_process_bar(self.trainer_callback), progress_bar: running_progress,
} }
if self.do_train: if self.do_train and running_loss is not None:
plot = gen_plot(output_path) return_dict[loss_viewer] = running_loss
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict yield return_dict
time.sleep(2) try:
self.trainer.wait(2)
self.trainer = None
except TimeoutExpired:
continue
if self.do_train: if self.do_train:
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)): if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
@ -337,16 +327,11 @@ class Runner:
return_dict = { return_dict = {
output_box: self._finalize(lang, finish_info), output_box: self._finalize(lang, finish_info),
process_bar: gr.Slider(visible=False), progress_bar: gr.Slider(visible=False),
} }
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict yield return_dict
def save_args(self, data): def save_args(self, data: dict):
output_box = self.manager.get_elem_by_id("train.output_box") output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True) error = self._initialize(data, do_train=True, from_preview=True)
if error: if error:

View File

@ -1,10 +1,13 @@
import json import json
import os import os
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import Any, Dict, List, Optional, Tuple
from yaml import safe_dump
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import smooth from ..extras.ploting import gen_loss_plot
from .locales import ALERTS from .locales import ALERTS
@ -12,30 +15,6 @@ if is_gradio_available():
import gradio as gr import gradio as gr
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
if not callback.max_steps:
return gr.Slider(visible=False)
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return gr.Slider(label=label, value=percentage, visible=True)
def get_time() -> str:
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def can_quantize(finetuning_type: str) -> "gr.Dropdown": def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora": if finetuning_type != "lora":
return gr.Dropdown(value="none", interactive=False) return gr.Dropdown(value="none", interactive=False)
@ -57,14 +36,19 @@ def check_json_schema(text: str, lang: str) -> None:
gr.Warning(ALERTS["err_json_schema"][lang]) gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str: def gen_cmd(args: Dict[str, Any]) -> str:
args.pop("disable_tqdm", None) args.pop("disable_tqdm", None)
args["plot_loss"] = args.get("do_train", None) args["plot_loss"] = args.get("do_train", None)
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0") current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)] cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
for k, v in args.items(): for k, v in clean_cmd(args).items():
if v is not None and v is not False and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v))) cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_text = "\\\n".join(cmd_lines) cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text) cmd_text = "```bash\n{}\n```".format(cmd_text)
return cmd_text return cmd_text
@ -76,29 +60,49 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result) return "```json\n{}\n```\n".format(result)
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]: def get_time() -> str:
log_file = os.path.join(output_path, "trainer_log.jsonl") return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
if not os.path.isfile(log_file) or not is_matplotlib_available():
return
plt.close("all")
plt.switch_backend("agg") def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
fig = plt.figure() running_log = ""
ax = fig.add_subplot(111) running_progress = gr.Slider(visible=False)
steps, losses = [], [] running_loss = None
with open(log_file, "r", encoding="utf-8") as f:
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, "r", encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, "r", encoding="utf-8") as f:
for line in f: for line in f:
log_info: Dict[str, Any] = json.loads(line) trainer_log.append(json.loads(line))
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0: if len(trainer_log) != 0:
return latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original") if is_matplotlib_available():
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed") running_loss = gr.Plot(gen_loss_plot(trainer_log))
ax.legend()
ax.set_xlabel("step") return running_log, running_progress, running_loss
ax.set_ylabel("loss")
return fig
def save_cmd(args: Dict[str, Any]) -> str:
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)

View File

@ -1,4 +1,4 @@
from llmtuner import run_exp from llmtuner.train.tuner import run_exp
def main(): def main():
@ -7,7 +7,7 @@ def main():
def _mp_fn(index): def _mp_fn(index):
# For xla_spawn (TPUs) # For xla_spawn (TPUs)
main() run_exp()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,9 +0,0 @@
from llmtuner import create_ui
def main():
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
if __name__ == "__main__":
main()

View File

@ -1,9 +0,0 @@
from llmtuner import create_web_demo
def main():
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
if __name__ == "__main__":
main()