zhaonx96 2024-05-06 10:09:00 +08:00
commit 1abd55dd59
85 changed files with 1276 additions and 747 deletions


@ -11,4 +11,4 @@ RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
EXPOSE 7860
CMD [ "python", "src/train_web.py" ]
CMD [ "llamafactory-cli webui" ]


@ -5,7 +5,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-42-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@ -339,16 +339,17 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
> [!IMPORTANT]
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
> LLaMA Board GUI only supports training on a single GPU; please use the [CLI](#train-with-command-line-interface) for distributed training.
#### Use local environment
```bash
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
python src/train_web.py # or python -m llmtuner.webui.interface
llamafactory-cli webui
```
> [!TIP]
> To modify the default settings of the LLaMA Board GUI, you can use environment variables, e.g., `export CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 GRADIO_SHARE=False` (use the `set` command on Windows).
<details><summary>For Alibaba Cloud users</summary>
If you encounter display problems in LLaMA Board on Alibaba Cloud, try setting the following environment variables before starting LLaMA Board:
@ -392,12 +393,13 @@ docker compose -f ./docker-compose.yml up -d
See [examples/README.md](examples/README.md) for usage.
Use `python src/train_bash.py -h` to display arguments description.
> [!TIP]
> Use `llamafactory-cli train -h` to display the argument descriptions.
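For instance, a minimal single-GPU LoRA fine-tuning run might look like the following sketch; the flags mirror the example scripts in this repository, and the dataset and output path are placeholders:

```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir saves/LLaMA2-7B/lora/sft
```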
### Deploy with OpenAI-style API and vLLM
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \
--infer_backend vllm \
@ -441,6 +443,7 @@ If you have a project that should be incorporated, please contact via email or c
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
@ -448,7 +451,14 @@ If you have a project that should be incorporated, please contact via email or c
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning over legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.


@ -5,7 +5,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-34-green)](#使用了-llama-factory-的项目)
[![Citation](https://img.shields.io/badge/citation-42-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@ -339,16 +339,17 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
### Train with the LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
> [!IMPORTANT]
> The LLaMA Board GUI currently only supports training on a single GPU; please use the [CLI](#命令行接口) for multi-GPU distributed training.
> The LLaMA Board GUI currently only supports training on a single GPU; please use the [CLI](#利用命令行接口训练) for multi-GPU distributed training.
#### Use local environment
```bash
export CUDA_VISIBLE_DEVICES=0 # use `set CUDA_VISIBLE_DEVICES=0` on Windows
export GRADIO_SERVER_PORT=7860 # use `set GRADIO_SERVER_PORT=7860` on Windows
python src/train_web.py # or python -m llmtuner.webui.interface
llamafactory-cli webui
```
> [!TIP]
> You can use environment variables to modify the default settings of the LLaMA Board GUI, e.g., `export CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 GRADIO_SHARE=False` (use the `set` command on Windows).
<details><summary>For Alibaba Cloud users</summary>
If you encounter display problems when using LLaMA Board on Alibaba Cloud, try setting the following environment variables before starting LLaMA Board:
@ -392,12 +393,13 @@ docker compose -f ./docker-compose.yml up -d
See [examples/README_zh.md](examples/README_zh.md) for usage.
You can run `python src/train_bash.py -h` to view the argument documentation.
> [!TIP]
> You can run `llamafactory-cli train -h` to view the argument documentation.
### Deploy with OpenAI-style API and vLLM
```bash
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--template llama3 \
--infer_backend vllm \
@ -441,6 +443,7 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
@ -448,7 +451,14 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: StarWhisper, a large language model for astronomy, fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: DISC-LawLLM, a large language model for the Chinese legal domain fine-tuned from Baichuan-13B, with legal reasoning and knowledge retrieval capabilities.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: Sunsimiao, a large language model for the Chinese medical domain, fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.

(Binary image file changed but not shown. Before: 186 KiB; after: 123 KiB.)


@ -1,4 +1,4 @@
If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
If you are using a custom dataset, please add your **dataset description** to `dataset_info.json` according to the following format. We also provide several examples in the next section.
```json
"dataset_name": {
@ -33,7 +33,7 @@ If you are using a custom dataset, please provide your dataset definition in the
}
```
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
After that, you can load the custom dataset by specifying `--dataset dataset_name`.
----
@ -54,10 +54,11 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
Regarding the above dataset, the description in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
@ -70,28 +71,60 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
The `query` column will be concatenated with the `prompt` column and used as the user prompt, i.e. the final user prompt will be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
The `system` column will be used as the system prompt. The `history` column is a list consisting of string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
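For reference, a complete record in the alpaca format, using the column names described above, looks like the following (placeholder values only):

```json
{
  "instruction": "user instruction",
  "input": "user input",
  "output": "model response",
  "system": "system prompt (optional)",
  "history": [
    ["first-round instruction", "first-round response"],
    ["second-round instruction", "second-round response"]
  ]
}
```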
For the pre-training datasets, only the `prompt` column will be used for training.
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
For the **pre-training datasets**, only the `prompt` column will be used for training, for example:
```json
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
[
{"text": "document"},
{"text": "document"}
]
```
Regarding the above dataset, the description in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "text"
}
}
```
Remember to set `"ranking": true` for the preference datasets.
For the **preference datasets**, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
```json
[
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
}
]
```
Regarding the above dataset, the description in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"ranking": true,
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
}
}
```
----
The dataset in sharegpt format should follow the below format:
The dataset in **sharegpt** format should follow the format below:
```json
[
@ -112,10 +145,12 @@ The dataset in sharegpt format should follow the below format:
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
Regarding the above dataset, the description in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"system": "system",
@ -132,4 +167,46 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
We also support datasets in the **openai** format:
```json
[
{
"messages": [
{
"role": "system",
"content": "system prompt (optional)"
},
{
"role": "user",
"content": "user instruction"
},
{
"role": "assistant",
"content": "model response"
}
]
}
]
```
Regarding the above dataset, the description in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant",
"system_tag": "system"
}
}
```
Pre-training datasets and preference datasets are **not yet supported** in the sharegpt format.


@ -1,4 +1,4 @@
If you are using a custom dataset, please be sure to provide your dataset definition in `dataset_info.json` in the following format.
If you are using a custom dataset, please be sure to add a **dataset description** to `dataset_info.json` in the following format. We also provide several examples below.
```json
"数据集名称": {
@ -33,7 +33,7 @@
}
```
After adding it, you can use the custom dataset by specifying the `--dataset dataset_name` argument.
After that, you can load the custom dataset by specifying the `--dataset dataset_name` argument.
----
@ -54,10 +54,11 @@
]
```
For data in the above format, the `columns` entry in `dataset_info.json` should be:
For data in the above format, the description in `dataset_info.json` should be:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
@ -70,28 +71,60 @@
The content of the `query` column will be concatenated with the content of the `prompt` column to form the user instruction, i.e. the user instruction will be `prompt\nquery`. The content of the `response` column is the model response.
The content of the `system` column will be used as the system prompt. The `history` column is a list of string tuples representing the instruction and response of each round in the history. Note that the responses in the history **will also be used for training**.
The content of the `system` column will be used as the system prompt. The `history` column is a list of string tuples representing the instruction and response of each round in the history. Note that during supervised fine-tuning, the responses in the history **will also be used for training**.
For pre-training datasets, only the content of the `prompt` column will be used for model training.
For preference datasets, the `response` column should be a string list of length 2, with the preferred answer appearing first, for example:
For **pre-training datasets**, only the content of the `prompt` column will be used for model training, for example:
```json
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"优质回答",
"劣质回答"
]
[
{"text": "document"},
{"text": "document"}
]
```
For data in the above format, the description in `dataset_info.json` should be:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "text"
}
}
```
To add a preference dataset, you additionally need to specify `"ranking": true`.
For **preference datasets**, the `response` column should be a string list of length 2, with the preferred answer appearing first, for example:
```json
[
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"优质回答",
"劣质回答"
]
}
]
```
For data in the above format, the description in `dataset_info.json` should be:
```json
"数据集名称": {
"file_name": "data.json",
"ranking": true,
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
}
}
```
----
A dataset in sharegpt format is organized as follows:
A dataset in the **sharegpt** format is organized as follows:
```json
[
@ -112,10 +145,12 @@
]
```
For data in the above format, the `columns` entry in `dataset_info.json` should be:
For data in the above format, the description in `dataset_info.json` should be:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"system": "system",
@ -132,4 +167,46 @@
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets do not yet support the sharegpt format.
We also support datasets in the **openai** format:
```json
[
{
"messages": [
{
"role": "system",
"content": "系统提示词(选填)"
},
{
"role": "user",
"content": "用户指令"
},
{
"role": "assistant",
"content": "模型回答"
}
]
}
]
```
For data in the above format, the description in `dataset_info.json` should be:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant",
"system_tag": "system"
}
}
```
Pre-training datasets and preference datasets are **not yet supported** in the sharegpt format.


@ -19,7 +19,7 @@ import pandas as pd
_CITATION = """\
@article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
journal={arXiv preprint arXiv:2305.08322},
year={2023}
@ -133,25 +133,19 @@ class Ceval(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(
data_dir, "test", f"{task_name}_test.csv"
),
"filepath": os.path.join(data_dir, "test", f"{task_name}_test.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(
data_dir, "val", f"{task_name}_val.csv"
),
"filepath": os.path.join(data_dir, "val", f"{task_name}_val.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(
data_dir, "dev", f"{task_name}_dev.csv"
),
"filepath": os.path.join(data_dir, "dev", f"{task_name}_dev.csv"),
},
),
]


@ -37,73 +37,73 @@ _LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 Internatio
_URL = "cmmlu.zip"
task_list = [
'agronomy',
'anatomy',
'ancient_chinese',
'arts',
'astronomy',
'business_ethics',
'chinese_civil_service_exam',
'chinese_driving_rule',
'chinese_food_culture',
'chinese_foreign_policy',
'chinese_history',
'chinese_literature',
'chinese_teacher_qualification',
'clinical_knowledge',
'college_actuarial_science',
'college_education',
'college_engineering_hydrology',
'college_law',
'college_mathematics',
'college_medical_statistics',
'college_medicine',
'computer_science',
'computer_security',
'conceptual_physics',
'construction_project_management',
'economics',
'education',
'electrical_engineering',
'elementary_chinese',
'elementary_commonsense',
'elementary_information_and_technology',
'elementary_mathematics',
'ethnology',
'food_science',
'genetics',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_geography',
'high_school_mathematics',
'high_school_physics',
'high_school_politics',
'human_sexuality',
'international_law',
'journalism',
'jurisprudence',
'legal_and_moral_basis',
'logical',
'machine_learning',
'management',
'marketing',
'marxist_theory',
'modern_chinese',
'nutrition',
'philosophy',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_study',
'sociology',
'sports_science',
'traditional_chinese_medicine',
'virology',
'world_history',
'world_religions',
"agronomy",
"anatomy",
"ancient_chinese",
"arts",
"astronomy",
"business_ethics",
"chinese_civil_service_exam",
"chinese_driving_rule",
"chinese_food_culture",
"chinese_foreign_policy",
"chinese_history",
"chinese_literature",
"chinese_teacher_qualification",
"clinical_knowledge",
"college_actuarial_science",
"college_education",
"college_engineering_hydrology",
"college_law",
"college_mathematics",
"college_medical_statistics",
"college_medicine",
"computer_science",
"computer_security",
"conceptual_physics",
"construction_project_management",
"economics",
"education",
"electrical_engineering",
"elementary_chinese",
"elementary_commonsense",
"elementary_information_and_technology",
"elementary_mathematics",
"ethnology",
"food_science",
"genetics",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_geography",
"high_school_mathematics",
"high_school_physics",
"high_school_politics",
"human_sexuality",
"international_law",
"journalism",
"jurisprudence",
"legal_and_moral_basis",
"logical",
"machine_learning",
"management",
"marketing",
"marxist_theory",
"modern_chinese",
"nutrition",
"philosophy",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_study",
"sociology",
"sports_science",
"traditional_chinese_medicine",
"virology",
"world_history",
"world_religions",
]


@ -136,25 +136,19 @@ class MMLU(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "test", f"{task_name}_test.csv"
),
"filepath": os.path.join(data_dir, "data", "test", f"{task_name}_test.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "val", f"{task_name}_val.csv"
),
"filepath": os.path.join(data_dir, "data", "val", f"{task_name}_val.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(
data_dir, "data", "dev", f"{task_name}_dev.csv"
),
"filepath": os.path.join(data_dir, "data", "dev", f"{task_name}_dev.csv"),
},
),
]


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
@ -10,7 +10,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--finetuning_type full \
--use_badam \
--badam_switch_mode descending \
--badam_switch_block_every 50 \
--badam_switch_interval 50 \
--badam_verbose 2 \
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
--overwrite_cache \


@ -7,7 +7,7 @@ pip install "bitsandbytes>=0.43.0"
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file ../../accelerate/fsdp_config.yaml \
../../../src/train_bash.py \
../../../src/train.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-70b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path ../../../models/llama2-7b-pro \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -6,7 +6,7 @@ python -m torch.distributed.run \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
../../src/train_bash.py \
../../src/train.py \
--deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \
--do_train \


@ -2,7 +2,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \
../../src/train.py \
--stage sft \
--do_predict \
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \


@ -1,6 +1,6 @@
#!/bin/bash
deepspeed --num_gpus 4 ../../src/train_bash.py \
deepspeed --num_gpus 4 ../../src/train.py \
--deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \
--do_train \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template fewshot \


@ -1,7 +1,7 @@
#!/bin/bash
# add `--visual_inputs True` to load an MLLM
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \


@ -1,6 +1,7 @@
#!/bin/bash
# ZeRO-3 enables weight sharding on multiple GPUs
deepspeed --num_gpus 4 ../../src/train_bash.py \
deepspeed --num_gpus 4 ../../src/train.py \
--deepspeed ../deepspeed/ds_z3_config.json \
--stage sft \
--do_train \


@ -3,7 +3,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/master_config.yaml \
../../src/train_bash.py \
../../src/train.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -2,7 +2,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file ../accelerate/single_config.yaml \
../../src/train_bash.py \
../../src/train.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage dpo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage orpo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage ppo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_predict \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,7 @@
#!/bin/bash
# use `--tokenized_path` in the training script to load the data
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES= llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage pt \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage rm \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \


@ -1,7 +1,7 @@
#!/bin/bash
# DO NOT use a quantized model or `quantization_bit` when merging LoRA weights
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
--template default \


@ -1,7 +1,7 @@
#!/bin/bash
# You need to run `merge.sh` before using this script
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
--model_name_or_path ../../models/llama2-7b-sft \
--template default \
--export_dir ../../models/llama2-7b-sft-int4 \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path TheBloke/Llama-2-7B-AWQ \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \


@ -1,6 +1,6 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
--stage sft \
--do_train \
--model_name_or_path TheBloke/Llama-2-7B-GPTQ \


@ -16,3 +16,4 @@ sse-starlette
matplotlib
fire
packaging
pyyaml


@ -3,24 +3,22 @@
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
from typing import Optional
import fire
import torch
from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
from llmtuner.chat import ChatModel
def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
flash_attn: Optional[bool] = False,
batch_size: int = 1,
seq_length: int = 256,
flash_attn: str = "auto",
):
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)


@ -4,7 +4,7 @@
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import math
from typing import Optional
from typing import Literal
import fire
import torch
@ -25,12 +25,12 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: Optional[str] = "sft",
dataset: Optional[str] = "alpaca_en",
dataset_dir: Optional[str] = "data",
template: Optional[str] = "default",
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False,  # Mistral models use a smaller learning rate
):
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@ -54,9 +54,7 @@ def calculate_lr(
else:
raise NotImplementedError
dataloader = DataLoader(
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
)
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
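Since the script is driven by `fire`, every parameter of `calculate_lr` maps to a command-line flag; a typical invocation (the model path is a placeholder) might be:

```bash
python scripts/cal_lr.py \
    --model_name_or_path path_to_model \
    --batch_size 256 \
    --dataset alpaca_en \
    --cutoff_len 1024
```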

scripts/cal_ppl.py (new file, 116 lines)

@ -0,0 +1,116 @@
# coding=utf-8
# Calculates the perplexity of a pre-trained model on a given dataset.
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
import json
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Sequence
import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
train_on_prompt: bool = False
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
Only the chosen sequences (prompt plus chosen answer) are used, so the collator
returns exactly one padded example per input feature.
"""
chosen_features = []
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
attention_mask = [1] * (prompt_len + answer_len)
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
return super().__call__(chosen_features)
def cal_ppl(
model_name_or_path: str,
save_name: str,
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
stage=stage,
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
max_samples=max_samples,
train_on_prompt=train_on_prompt,
output_dir="dummy_dir",
overwrite_cache=True,
)
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
elif stage == "rm":
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
)
else:
raise NotImplementedError
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]
with torch.no_grad():
for batch in tqdm(dataloader):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()
perplexities.extend(sentence_logps.exp().tolist())
with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2)
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
print("Perplexities have been saved at {}.".format(save_name))
if __name__ == "__main__":
fire.Fire(cal_ppl)


@ -3,7 +3,6 @@
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
from collections import defaultdict
from typing import Optional
import fire
from tqdm import tqdm
@ -15,10 +14,10 @@ from llmtuner.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
dataset: Optional[str] = "alpaca_en",
dataset_dir: Optional[str] = "data",
template: Optional[str] = "default",
interval: Optional[int] = 1000,
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
interval: int = 1000,
):
model_args, data_args, training_args, _, _ = get_train_args(
dict(


@ -52,6 +52,7 @@ def main():
python_requires=">=3.8.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",


@ -1,16 +0,0 @@
import os
import uvicorn
from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
if __name__ == "__main__":
main()


@ -1,49 +0,0 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
def main():
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()


@ -1,9 +0,0 @@
from llmtuner import Evaluator
def main():
Evaluator().eval()
if __name__ == "__main__":
main()


@ -1,9 +0,0 @@
from llmtuner import export_model
def main():
export_model()
if __name__ == "__main__":
main()


@ -1,11 +1,3 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams
from .api import create_app
from .chat import ChatModel
from .eval import Evaluator
from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.7.0"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
__version__ = "0.7.1.dev0"


@ -1,4 +0,0 @@
from .app import create_app
__all__ = ["create_app"]


@ -1,36 +1,29 @@
import json
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, Sequence
from pydantic import BaseModel
from typing import Annotated, Optional
from ..chat import ChatModel
from ..data import Role as DataRole
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
from .chat import (
create_chat_completion_response,
create_score_evaluation_response,
create_stream_chat_completion_response,
)
from .protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
Finish,
Function,
FunctionCall,
ModelCard,
ModelList,
Role,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
)
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
if is_starlette_available():
@ -47,23 +40,8 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
torch_gc()
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@ -71,162 +49,58 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY", None)
security = HTTPBearer(auto_error=False)
role_mapping = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
@app.get("/v1/models", response_model=ModelList)
@app.get(
"/v1/models",
response_model=ModelList,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
@app.post(
"/v1/chat/completions",
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = ""
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = ""
if request.stream:
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(input_messages, system, tools, request)
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
return await create_chat_completion_response(request, chat_model)
responses = await chat_model.achat(
input_messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
stop=request.stop
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
response_message = ChatCompletionMessage(
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
)
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def stream_chat_completion(
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
):
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
async for new_token in chat_model.astream_chat(
messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
stop=request.stop
):
if len(new_token) == 0:
continue
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield jsonify(chunk)
yield "[DONE]"
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
@app.post(
"/v1/score/evaluation",
response_model=ScoreEvaluationResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)
return await create_score_evaluation_response(request, chat_model)
return app
if __name__ == "__main__":
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)
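A quick way to smoke-test the server together with the new `API_KEY` check is an OpenAI-style request carrying a bearer token; this is a sketch, and the key value is an assumption:

```bash
# launch the server in one shell
API_KEY=sk-test API_PORT=8000 llamafactory-cli api \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --template default

# query it from another shell; requests without the matching token receive HTTP 401
curl http://localhost:8000/v1/chat/completions \
    -H "Authorization: Bearer sk-test" \
    -H "Content-Type: application/json" \
    -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'
```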

src/llmtuner/api/chat.py (new file, 177 lines)

@ -0,0 +1,177 @@
import json
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.packages import is_fastapi_availble
from .common import dictify, jsonify
from .protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
Function,
FunctionCall,
Role,
ScoreEvaluationResponse,
)
if is_fastapi_availble():
from fastapi import HTTPException, status
if TYPE_CHECKING:
from ..chat import ChatModel
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = ""
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = ""
return input_messages, system, tools
def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text)
else:
result = response.response_text
if isinstance(result, tuple):
name, arguments = result
function = Function(name=name, arguments=arguments)
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
)
async for new_token in chat_model.astream_chat(
input_messages,
system,
tools,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
):
if len(new_token) != 0:
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
)
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)


@ -0,0 +1,20 @@
import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
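As a quick illustration of why these helpers exist (a sketch; the `Msg` model is hypothetical), `exclude_unset=True` keeps only the fields the caller explicitly set, and the same call works on both pydantic v1 and v2:

```python
from pydantic import BaseModel

class Msg(BaseModel):
    role: str = "assistant"
    content: str = ""

# only the explicitly provided field survives exclude_unset=True
print(dictify(Msg(content="hello")))  # {'content': 'hello'}
print(jsonify(Msg(content="hello")))  # {"content": "hello"}
```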

View File

@ -51,7 +51,7 @@ class FunctionAvailable(BaseModel):
class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default"
id: str
type: Literal["function"] = "function"
function: Function
@ -87,7 +87,7 @@ class ChatCompletionResponseChoice(BaseModel):
finish_reason: Finish
class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
@ -100,7 +100,7 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
@ -109,11 +109,11 @@ class ChatCompletionResponse(BaseModel):
class ChatCompletionStreamResponse(BaseModel):
id: Literal["chatcmpl-default"] = "chatcmpl-default"
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
choices: List[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
@ -123,7 +123,7 @@ class ScoreEvaluationRequest(BaseModel):
class ScoreEvaluationResponse(BaseModel):
id: Literal["scoreeval-default"] = "scoreeval-default"
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]

View File

@ -2,6 +2,7 @@ import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine
@ -95,3 +96,45 @@ class ChatModel:
**input_kwargs,
) -> List[float]:
return await self.engine.get_scores(batch_input, **input_kwargs)
def run_chat() -> None:
try:
import platform
if platform.system() != "Windows":
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})

src/llmtuner/cli.py Normal file
View File

@ -0,0 +1,59 @@
import sys
from enum import Enum, unique
from . import __version__
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
USAGE = """
Usage:
llamafactory-cli api -h: launch an API server
llamafactory-cli chat -h: launch a chat interface in CLI
llamafactory-cli eval -h: do evaluation
llamafactory-cli export -h: merge LoRA adapters and export model
llamafactory-cli train -h: do training
llamafactory-cli webchat -h: launch a chat interface in Web UI
llamafactory-cli webui: launch LlamaBoard
llamafactory-cli version: show version info
"""
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VERSION = "version"
HELP = "help"
def main():
command = sys.argv.pop(1)
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VERSION:
print("Welcome to LLaMA Factory, version {}".format(__version__))
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}".format(command))
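For `llamafactory-cli` to resolve to this `main()`, the package presumably registers a console-script entry point; the exact packaging wiring is not part of this commit, so the sketch below is an assumption:

```python
# Hypothetical setup.py excerpt: maps the llamafactory-cli command to llmtuner.cli:main.
from setuptools import find_packages, setup

setup(
    name="llmtuner",
    package_dir={"": "src"},
    packages=find_packages("src"),
    entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
)
```

One design note: `sys.argv.pop(1)` raises `IndexError` when no subcommand is given, so falling back to printing `USAGE` in that case would be a natural hardening.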

View File

@ -1,4 +0,0 @@
from .evaluator import Evaluator
__all__ = ["Evaluator"]

View File

@ -118,6 +118,5 @@ class Evaluator:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()
def run_eval() -> None:
Evaluator().eval()

View File

@ -1,14 +1,19 @@
import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional
import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import LOG_FILE_NAME
from .logging import get_logger
from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
@ -33,57 +38,92 @@ class FixValueHeadModelCallback(TrainerCallback):
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.in_training = False
self.start_time = time.time()
def __init__(self, output_dir: str) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
self.aborted = False
self.do_train = False
""" Web UI """
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def timing(self):
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = max_steps
self.elapsed_time = ""
self.remaining_time = ""
def _timing(self, cur_steps: int) -> None:
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def _create_thread_pool(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _close_thread_pool(self) -> None:
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if state.is_local_process_zero:
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if state.is_local_process_zero:
self.in_training = False
self.cur_steps = 0
self.max_steps = 0
self._close_thread_pool()
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a substep during gradient accumulation.
"""
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@ -91,42 +131,30 @@ class LogCallback(TrainerCallback):
r"""
Event called at the end of a training step.
"""
if state.is_local_process_zero:
self.cur_steps = state.global_step
self.timing()
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
self._close_thread_pool()
def on_predict(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
):
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if state.is_local_process_zero and not self.in_training:
self.cur_steps = 0
self.max_steps = 0
self._close_thread_pool()
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if not args.should_save:
return
self._timing(cur_steps=state.global_step)
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
@ -141,16 +169,16 @@ class LogCallback(TrainerCallback):
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
if self.runner is not None:
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
logs["loss"], logs["learning_rate"], logs["epoch"]
)
)
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
@ -158,9 +186,28 @@ class LogCallback(TrainerCallback):
r"""
Event called after a prediction step.
"""
if self.do_train:
return
if self.aborted:
sys.exit(0)
if not args.should_save:
return
eval_dataloader = kwargs.pop("eval_dataloader", None)
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
if has_length(eval_dataloader):
if self.max_steps == 0:
self.max_steps = len(eval_dataloader)
self.cur_steps += 1
self.timing()
self._reset(max_steps=len(eval_dataloader))
self._create_thread_pool(output_dir=args.output_dir)
self._timing(cur_steps=self.cur_steps + 1)
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)
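The callback now offloads disk writes to a single-worker `ThreadPoolExecutor`, so training steps never block on I/O while writes still land in order. The pattern in isolation, as a runnable sketch:

```python
# Minimal sketch of the write-offloading pattern adopted by LogCallback:
# one worker keeps appends ordered; shutdown(wait=True) flushes pending writes.
import json
import os
from concurrent.futures import ThreadPoolExecutor

def write_log(output_dir: str, logs: dict) -> None:
    with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(logs) + "\n")

os.makedirs("out", exist_ok=True)
pool = ThreadPoolExecutor(max_workers=1)
pool.submit(write_log, "out", {"current_steps": 1, "total_steps": 100})
pool.shutdown(wait=True)  # mirrors _close_thread_pool()
```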

View File

@ -24,8 +24,6 @@ IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
MLLM_LIST = ["LLaVA1.5"]
@ -34,10 +32,16 @@ MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral
PEFT_METHODS = ["lora"]
RUNNING_LOG = "running_log.txt"
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINER_CONFIG = "trainer_config.yaml"
TRAINER_LOG = "trainer_log.jsonl"
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",

View File

@ -1,5 +1,9 @@
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from .constants import RUNNING_LOG
class LoggerHandler(logging.Handler):
@ -7,19 +11,35 @@ class LoggerHandler(logging.Handler):
Logger handler used in Web UI.
"""
def __init__(self):
def __init__(self, output_dir: str) -> None:
super().__init__()
self.log = ""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
self.setLevel(logging.INFO)
self.setFormatter(formatter)
def reset(self):
self.log = ""
os.makedirs(output_dir, exist_ok=True)
self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log):
os.remove(self.running_log)
def emit(self, record):
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _write_log(self, log_entry: str) -> None:
with open(self.running_log, "a", encoding="utf-8") as f:
f.write(log_entry + "\n\n")
def emit(self, record) -> None:
if record.name == "httpx":
return
log_entry = self.format(record)
self.log += log_entry
self.log += "\n\n"
self.thread_pool.submit(self._write_log, log_entry)
def close(self) -> None:
self.thread_pool.shutdown(wait=True)
return super().close()
def get_logger(name: str) -> logging.Logger:
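A usage sketch for the reworked handler, assuming a scratch directory: records are formatted, filtered (httpx is dropped), and appended to `running_log.txt` by a background worker instead of being buffered on the instance:

```python
# Hedged sketch; LoggerHandler is the class defined above.
import logging
import tempfile

output_dir = tempfile.mkdtemp()
handler = LoggerHandler(output_dir)
logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

logger.info("hello from the web ui")
handler.close()  # waits for the writer thread before closing

with open(handler.running_log, encoding="utf-8") as f:
    print(f.read())
```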

View File

@ -1,7 +1,7 @@
import json
import math
import os
from typing import List
from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME
@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
@ -21,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
smoothed = []
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
@ -30,7 +31,27 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
for log in trainer_log:
if log.get("loss", None):
steps.append(log["current_steps"])
losses.append(log["loss"])
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)
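`gen_loss_plot` consumes already-parsed `trainer_log.jsonl` entries rather than re-reading the file, which is what lets the Web UI reuse a single read for both the progress bar and the loss curve. A sketch with synthetic data:

```python
# Synthetic trainer-log entries; gen_loss_plot is the function defined above.
trainer_log = [
    {"current_steps": step, "loss": 2.0 * (0.99 ** step)} for step in range(1, 101)
]
fig = gen_loss_plot(trainer_log)
fig.savefig("loss.png")  # original and EMA-smoothed curves on one axis
```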

View File

@ -221,16 +221,18 @@ class BAdamArgument:
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_block_every: Optional[int] = field(
default=50,
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_switch_interval: Optional[int] = field(
default=50,
metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
},
)
badam_update_ratio: float = field(
default=0.0,
default=0.05,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
@ -308,6 +310,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")

View File

@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.constants import TRAINER_CONFIG
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
@ -251,7 +252,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
files = os.listdir(training_args.output_dir)
if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:

View File

@ -1,4 +0,0 @@
from .tuner import export_model, run_exp
__all__ = ["export_model", "run_exp"]

View File

@ -165,13 +165,13 @@ class CustomDPOTrainer(DPOTrainer):
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
return losses.mean(), metrics
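Swapping `.cpu().mean()` for `.mean().cpu()` is a small but real win: the reduction now happens on the accelerator and only a scalar crosses the device boundary, instead of the whole reward tensor. In isolation:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rewards = torch.randn(4096, device=device)

slow = rewards.cpu().mean()  # copies 4096 floats to host, then reduces
fast = rewards.mean().cpu()  # reduces on device, copies one scalar
assert torch.allclose(slow, fast, atol=1e-5)
```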

View File

@ -113,15 +113,15 @@ class CustomORPOTrainer(DPOTrainer):
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().mean().cpu()
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().mean().cpu()
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().mean().cpu()
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().mean().cpu()
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().mean().cpu()
return batch_loss, metrics

View File

@ -23,9 +23,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks
callbacks.append(LogCallback(training_args.output_dir))
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
@ -43,7 +43,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None):
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)
if model_args.export_dir is None:
@ -88,7 +88,3 @@ def export_model(args: Optional[Dict[str, Any]] = None):
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
run_exp()
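One caveat introduced by the new signature: `callbacks: List["TrainerCallback"] = []` is a mutable default, created once at definition time, and `callbacks.append(LogCallback(...))` mutates it, so repeated `run_exp()` calls within one process keep accumulating callbacks. The pitfall and the usual idiom:

```python
# Demonstration of the shared-mutable-default pitfall in the new signature.
def run(callbacks=[]):           # one list object, shared across calls
    callbacks.append("LogCallback")
    return callbacks

print(run())  # ['LogCallback']
print(run())  # ['LogCallback', 'LogCallback'] -- state leaked

def run_fixed(callbacks=None):   # conventional workaround
    callbacks = [] if callbacks is None else list(callbacks)
    callbacks.append("LogCallback")
    return callbacks
```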

View File

@ -317,14 +317,14 @@ def _create_badam_optimizer(
base_optimizer=base_optimizer,
named_parameters_list=list(model.named_parameters()),
block_prefix_list=None,
switch_block_every=finetuning_args.badam_switch_block_every,
switch_block_every=finetuning_args.badam_switch_interval,
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
)
logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_block_every} steps, "
f"switch block every {finetuning_args.badam_switch_interval} steps, "
f"default start block is {finetuning_args.badam_start_block}"
)

View File

@ -1,4 +0,0 @@
from .interface import create_ui, create_web_demo
__all__ = ["create_ui", "create_web_demo"]

View File

@ -4,6 +4,7 @@ from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from yaml import safe_dump, safe_load
from ..extras.constants import (
DATA_CONFIG,
@ -16,6 +17,7 @@ from ..extras.constants import (
TRAINING_STAGES,
DownloadSource,
)
from ..extras.logging import get_logger
from ..extras.misc import use_modelscope
from ..extras.packages import is_gradio_available
@ -24,12 +26,15 @@ if is_gradio_available():
import gradio as gr
logger = get_logger(__name__)
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
USER_CONFIG = "user_config.yaml"
def get_save_dir(*args) -> os.PathLike:
@ -47,7 +52,7 @@ def get_save_path(config_path: str) -> os.PathLike:
def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -60,13 +65,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False)
safe_dump(user_config, f)
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
try:
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
return json.load(f)
return safe_load(f)
except Exception:
return None
@ -74,7 +79,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=2, ensure_ascii=False)
safe_dump(config_dict, f)
return str(get_save_path(config_path))
@ -127,11 +132,15 @@ def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
if dataset_dir == "ONLINE":
logger.info("dataset_dir is ONLINE, using online dataset.")
return {}
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f)
except Exception as err:
print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
return {}
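With the move from JSON to YAML, `user_config.yaml` and the saved argument files round-trip through `safe_load`/`safe_dump`. A quick sketch using the field names from `load_config` above:

```python
from yaml import safe_dump, safe_load

user_config = {"lang": "en", "last_model": None, "path_dict": {}, "cache_dir": None}
text = safe_dump(user_config)
print(text)  # human-editable YAML instead of JSON
assert safe_load(text) == user_config
```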

View File

@ -36,9 +36,9 @@ def create_chat_box(
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])

View File

@ -21,25 +21,25 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
@ -52,19 +52,19 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
process_bar = gr.Slider(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
output_elems = [output_box, process_bar]
output_elems = [output_box, progress_bar]
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
process_bar=process_bar,
progress_bar=progress_bar,
output_box=output_box,
)
)

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict, Generator, List
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train import export_model
from ...train.tuner import export_model
from ..common import get_save_dir
from ..locales import ALERTS
@ -85,7 +85,7 @@ def save_model(
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")

View File

@ -27,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, scale=4)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset})
@ -52,10 +52,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
@ -71,10 +71,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as extra_tab:
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
optim = gr.Textbox(value="adamw_torch")
with gr.Row():
@ -124,7 +124,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as freeze_tab:
with gr.Row():
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
num_layer_trainable = gr.Slider(minimum=1, maximum=128, value=2, step=1)
name_module_trainable = gr.Textbox(value="all")
input_elems.update({num_layer_trainable, name_module_trainable})
@ -136,10 +136,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
create_new_adapter = gr.Checkbox()
with gr.Row():
@ -180,9 +180,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
@ -193,9 +193,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as galore_tab:
with gr.Row():
use_galore = gr.Checkbox()
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
@ -210,6 +210,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)
with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
elem_dict.update(
dict(
badam_tab=badam_tab,
use_badam=use_badam,
badam_mode=badam_mode,
badam_switch_mode=badam_switch_mode,
badam_switch_interval=badam_switch_interval,
badam_update_ratio=badam_update_ratio,
)
)
with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
@ -225,7 +245,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
process_bar = gr.Slider(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
@ -243,14 +263,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir=output_dir,
config_path=config_path,
resume_btn=resume_btn,
process_bar=process_bar,
progress_bar=progress_bar,
output_box=output_box,
loss_viewer=loss_viewer,
)
)
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar, loss_viewer]
output_elems = [output_box, progress_bar, loss_viewer]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)

View File

@ -41,7 +41,7 @@ class Engine:
init_dict["train.dataset"] = {"choices": list_dataset().choices}
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False}
@ -51,7 +51,7 @@ class Engine:
yield self._update_component(init_dict)
if self.runner.alive and not self.demo_mode and not self.pure_chat:
if self.runner.running and not self.demo_mode and not self.pure_chat:
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
if self.runner.do_train:
yield self._update_component({"train.resume_btn": {"value": True}})

View File

@ -68,5 +68,9 @@ def create_web_demo() -> gr.Blocks:
return demo
if __name__ == "__main__":
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
def run_web_ui() -> None:
create_ui().queue().launch()
def run_web_demo() -> None:
create_web_demo().queue().launch()

View File

@ -891,6 +891,87 @@ LOCALES = {
"info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
},
},
"badam_tab": {
"en": {
"label": "BAdam configurations",
},
"ru": {
"label": "Конфигурации BAdam",
},
"zh": {
"label": "BAdam 参数设置",
},
},
"use_badam": {
"en": {
"label": "Use BAdam",
"info": "Enable the BAdam optimizer.",
},
"ru": {
"label": "Использовать BAdam",
"info": "Включите оптимизатор BAdam.",
},
"zh": {
"label": "使用 BAdam",
"info": "使用 BAdam 优化器。",
},
},
"badam_mode": {
"en": {
"label": "BAdam mode",
"info": "Whether to use layer-wise or ratio-wise BAdam optimizer.",
},
"ru": {
"label": "Режим BAdam",
"info": "Использовать ли оптимизатор BAdam с послоевой или пропорциональной настройкой.",
},
"zh": {
"label": "BAdam 模式",
"info": "使用 layer-wise 或 ratio-wise BAdam 优化器。",
},
},
"badam_switch_mode": {
"en": {
"label": "Switch mode",
"info": "The strategy of picking block to update for layer-wise BAdam.",
},
"ru": {
"label": "Режим переключения",
"info": "Стратегия выбора блока для обновления для послойного BAdam.",
},
"zh": {
"label": "切换策略",
"info": "Layer-wise BAdam 优化器的块切换策略。",
},
},
"badam_switch_interval": {
"en": {
"label": "Switch interval",
"info": "Number of steps to update the block for layer-wise BAdam.",
},
"ru": {
"label": "Интервал переключения",
"info": "количество шагов для обновления блока для пошагового BAdam.",
},
"zh": {
"label": "切换频率",
"info": "Layer-wise BAdam 优化器的块切换频率。",
},
},
"badam_update_ratio": {
"en": {
"label": "Update ratio",
"info": "The ratio of the update for ratio-wise BAdam.",
},
"ru": {
"label": "Коэффициент обновления",
"info": "Коэффициент обновления для BAdam с учётом соотношений.",
},
"zh": {
"label": "Block 更新比例",
"info": "Ratio-wise BAdam 优化器的更新比例。",
},
},
"cmd_preview_btn": {
"en": {
"value": "Preview command",
@ -1368,7 +1449,7 @@ ALERTS = {
"info_aborting": {
"en": "Aborted, wait for terminating...",
"ru": "Прервано, ожидание завершения...",
"zh": "训练中断,正在等待线程结束……",
"zh": "训练中断,正在等待程结束……",
},
"info_aborted": {
"en": "Ready.",

View File

@ -1,22 +1,19 @@
import logging
import os
import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator
import signal
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import transformers
import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available
from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_gradio_available
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
if is_gradio_available():
@ -34,24 +31,18 @@ class Runner:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.thread: "Thread" = None
self.trainer: Optional["Popen"] = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
""" State """
self.aborted = False
self.running = False
""" Handler """
self.logger_handler = LoggerHandler()
self.logger_handler.setLevel(logging.INFO)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
@property
def alive(self) -> bool:
return self.thread is not None
def set_abort(self) -> None:
self.aborted = True
if self.trainer is not None:
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
os.kill(children.pid, signal.SIGABRT)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@ -85,13 +76,11 @@ class Runner:
if not from_preview and not is_torch_cuda_available():
gr.Warning(ALERTS["warn_no_cuda"][lang])
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
return ""
def _finalize(self, lang: str, finish_info: str) -> str:
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
self.thread = None
self.trainer = None
self.aborted = False
self.running = False
self.running_data = None
@ -147,12 +136,12 @@ class Runner:
shift_attn=get("train.shift_attn"),
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
)
args["disable_tqdm"] = True
if args["finetuning_type"] == "freeze":
args["num_layer_trainable"] = get("train.num_layer_trainable")
@ -198,6 +187,12 @@ class Runner:
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
@ -237,7 +232,6 @@ class Runner:
temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
)
args["disable_tqdm"] = True
if get("eval.predict"):
args["do_predict"] = True
@ -263,11 +257,12 @@ class Runner:
gr.Warning(error)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.do_train, self.running_data = do_train, data
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
env = deepcopy(os.environ)
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
env["LLAMABOARD_ENABLED"] = "1"
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor()
def preview_train(self, data):
@ -283,10 +278,10 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self):
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.aborted = False
self.running = True
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
@ -294,28 +289,31 @@ class Runner:
output_path = get_save_dir(model_name, finetuning_type, output_dir)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.thread is not None and self.thread.is_alive():
while self.trainer is not None:
if self.aborted:
yield {
output_box: ALERTS["info_aborting"][lang],
process_bar: gr.Slider(visible=False),
progress_bar: gr.Slider(visible=False),
}
else:
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
return_dict = {
output_box: self.logger_handler.log,
process_bar: update_process_bar(self.trainer_callback),
output_box: running_log,
progress_bar: running_progress,
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
if running_loss is not None:
return_dict[loss_viewer] = running_loss
yield return_dict
time.sleep(2)
try:
self.trainer.wait(2)
self.trainer = None
except TimeoutExpired:
continue
if self.do_train:
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
@ -330,16 +328,11 @@ class Runner:
return_dict = {
output_box: self._finalize(lang, finish_info),
process_bar: gr.Slider(visible=False),
progress_bar: gr.Slider(visible=False),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict
def save_args(self, data):
def save_args(self, data: dict):
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
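The runner thus no longer trains in an in-process `Thread`: it shells out to `llamafactory-cli train` with `LLAMABOARD_ENABLED=1`, and `set_abort` signals the child processes, which the new `LogCallback` signal handler turns into a graceful stop. A reduced sketch of the handshake (the sleep command stands in for the real training process):

```python
# Hedged sketch of the launch/abort handshake; psutil is a dependency above.
import os
import signal
import subprocess

import psutil

env = {**os.environ, "LLAMABOARD_ENABLED": "1"}
trainer = subprocess.Popen("sleep 60", env=env, shell=True)  # stand-in command

# Abort path: signal every child; during training, LogCallback._set_abort()
# catches SIGABRT and flips self.aborted so the loop stops at the next step.
for child in psutil.Process(trainer.pid).children():
    os.kill(child.pid, signal.SIGABRT)
trainer.wait()
```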

View File

@ -1,10 +1,13 @@
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple
from yaml import safe_dump
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import smooth
from ..extras.ploting import gen_loss_plot
from .locales import ALERTS
@ -12,30 +15,6 @@ if is_gradio_available():
import gradio as gr
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
if not callback.max_steps:
return gr.Slider(visible=False)
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
label = "Running {:d}/{:d}: {} < {}".format(
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return gr.Slider(label=label, value=percentage, visible=True)
def get_time() -> str:
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
return gr.Dropdown(value="none", interactive=False)
@ -57,14 +36,18 @@ def check_json_schema(text: str, lang: str) -> None:
gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
args.pop("disable_tqdm", None)
args["plot_loss"] = args.get("do_train", None)
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
for k, v in args.items():
if v is not None and v is not False and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_lines = ["CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(current_devices)]
for k, v in clean_cmd(args).items():
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text)
return cmd_text
@ -76,29 +59,49 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
log_file = os.path.join(output_path, "trainer_log.jsonl")
if not os.path.isfile(log_file) or not is_matplotlib_available():
return
def get_time() -> str:
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info: Dict[str, Any] = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, "r", encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, "r", encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
if len(trainer_log) != 0:
latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
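For reference, one `trainer_log.jsonl` entry as emitted by `LogCallback._write_log` and consumed here (values illustrative):

```python
# A representative parsed entry; get_trainer_info reads these fields.
entry = {
    "current_steps": 50,
    "total_steps": 1000,
    "loss": 1.2345,
    "learning_rate": 5e-05,
    "epoch": 0.15,
    "percentage": 5.0,
    "elapsed_time": "0:01:40",
    "remaining_time": "0:31:40",
}
```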
def save_cmd(args: Dict[str, Any]) -> str:
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)

View File

@ -1,4 +1,4 @@
from llmtuner import run_exp
from llmtuner.train.tuner import run_exp
def main():
@ -7,7 +7,7 @@ def main():
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
run_exp()
if __name__ == "__main__":

View File

@ -1,9 +0,0 @@
from llmtuner import create_ui
def main():
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
if __name__ == "__main__":
main()

View File

@ -1,9 +0,0 @@
from llmtuner import create_web_demo
def main():
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
if __name__ == "__main__":
main()