forked from p04798526/LLaMA-Factory-Mirror
Merge branch 'main' of https://github.com/zhaonx/LLaMA-Factory into dev
This commit is contained in:
commit
1abd55dd59
|
@ -11,4 +11,4 @@ RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
|
|||
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
||||
EXPOSE 7860
|
||||
|
||||
CMD [ "python", "src/train_web.py" ]
|
||||
CMD [ "llamafactory-cli webui" ]
|
||||
|
|
24
README.md
24
README.md
|
@ -5,7 +5,7 @@
|
|||
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
|
||||
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
|
||||
[![Citation](https://img.shields.io/badge/citation-34-green)](#projects-using-llama-factory)
|
||||
[![Citation](https://img.shields.io/badge/citation-42-green)](#projects-using-llama-factory)
|
||||
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
|
||||
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
|
||||
|
@ -339,16 +339,17 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
|||
### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
|
||||
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#train-with-command-line-interface) for distributed training.
|
||||
|
||||
#### Use local environment
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
|
||||
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
|
||||
python src/train_web.py # or python -m llmtuner.webui.interface
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> To modify the default setting in the LLaMA Board GUI, you can use environment variables, e.g., `export CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 GRADIO_SHARE=False` (use `set` command on Windows OS).
|
||||
|
||||
<details><summary>For Alibaba Cloud users</summary>
|
||||
|
||||
If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
|
||||
|
@ -392,12 +393,13 @@ docker compose -f ./docker-compose.yml up -d
|
|||
|
||||
See [examples/README.md](examples/README.md) for usage.
|
||||
|
||||
Use `python src/train_bash.py -h` to display arguments description.
|
||||
> [!TIP]
|
||||
> Use `llamafactory-cli train -h` to display arguments description.
|
||||
|
||||
### Deploy with OpenAI-style API and vLLM
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
|
||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--template llama3 \
|
||||
--infer_backend vllm \
|
||||
|
@ -441,6 +443,7 @@ If you have a project that should be incorporated, please contact via email or c
|
|||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||
|
@ -448,7 +451,14 @@ If you have a project that should be incorporated, please contact via email or c
|
|||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
||||
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
||||
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
||||
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
||||
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
|
|
24
README_zh.md
24
README_zh.md
|
@ -5,7 +5,7 @@
|
|||
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
|
||||
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
|
||||
[![Citation](https://img.shields.io/badge/citation-34-green)](#使用了-llama-factory-的项目)
|
||||
[![Citation](https://img.shields.io/badge/citation-42-green)](#使用了-llama-factory-的项目)
|
||||
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
|
||||
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
|
||||
|
@ -339,16 +339,17 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||
### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#利用命令行接口训练)来进行多 GPU 分布式训练。
|
||||
|
||||
#### 使用本地环境
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
|
||||
export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
|
||||
python src/train_web.py # 或 python -m llmtuner.webui.interface
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> 您可以使用环境变量来修改 LLaMA Board 可视化界面的默认设置,例如 `export CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 GRADIO_SHARE=False`(Windows 系统可使用 `set` 指令)。
|
||||
|
||||
<details><summary>阿里云用户指南</summary>
|
||||
|
||||
如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
|
||||
|
@ -392,12 +393,13 @@ docker compose -f ./docker-compose.yml up -d
|
|||
|
||||
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
|
||||
|
||||
您可以执行 `python src/train_bash.py -h` 来查看参数文档。
|
||||
> [!TIP]
|
||||
> 您可以执行 `llamafactory-cli train -h` 来查看参数文档。
|
||||
|
||||
### 利用 vLLM 部署 OpenAI API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
|
||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--template llama3 \
|
||||
--infer_backend vllm \
|
||||
|
@ -441,6 +443,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
|||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||
|
@ -448,7 +451,14 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
|||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
||||
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
||||
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
||||
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
||||
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 186 KiB After Width: | Height: | Size: 123 KiB |
113
data/README.md
113
data/README.md
|
@ -1,4 +1,4 @@
|
|||
If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
|
||||
If you are using a custom dataset, please add your **dataset description** to `dataset_info.json` according to the following format. We also provide several examples in the next section.
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
|
@ -33,7 +33,7 @@ If you are using a custom dataset, please provide your dataset definition in the
|
|||
}
|
||||
```
|
||||
|
||||
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
|
||||
After that, you can load the custom dataset by specifying `--dataset dataset_name`.
|
||||
|
||||
----
|
||||
|
||||
|
@ -54,10 +54,11 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
|
|||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
|
@ -70,28 +71,60 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
|||
|
||||
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
|
||||
|
||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||
|
||||
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
For the **pre-training datasets**, only the `prompt` column will be used for training, for example:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
[
|
||||
{"text": "document"},
|
||||
{"text": "document"}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "text"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Remember to set `"ranking": true` for the preference datasets.
|
||||
For the **preference datasets**, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
----
|
||||
|
||||
The dataset in sharegpt format should follow the below format:
|
||||
The dataset in **sharegpt** format should follow the below format:
|
||||
|
||||
```json
|
||||
[
|
||||
|
@ -112,10 +145,12 @@ The dataset in sharegpt format should follow the below format:
|
|||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"system": "system",
|
||||
|
@ -132,4 +167,46 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
|||
|
||||
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
|
||||
|
||||
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
|
||||
We also supports the dataset in the **openai** format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "system prompt (optional)"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "user instruction"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "model response"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant",
|
||||
"system_tag": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Pre-training datasets and preference datasets are **incompatible** with the sharegpt format yet.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
|
||||
如果您使用自定义数据集,请务必按照以下格式在 `dataset_info.json` 文件中添加**数据集描述**。我们在下面也提供了一些例子。
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
|
@ -33,7 +33,7 @@
|
|||
}
|
||||
```
|
||||
|
||||
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
|
||||
然后,可通过使用 `--dataset 数据集名称` 参数加载自定义数据集。
|
||||
|
||||
----
|
||||
|
||||
|
@ -54,10 +54,11 @@
|
|||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
|
@ -70,28 +71,60 @@
|
|||
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**也会被用于训练**。
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
对于**预训练数据集**,仅 `prompt` 列中的内容会用于模型训练,例如:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
[
|
||||
{"text": "document"},
|
||||
{"text": "document"}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "text"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
添加偏好数据集需要额外指定 `"ranking": true`。
|
||||
对于**偏好数据集**,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
----
|
||||
|
||||
而 sharegpt 格式的数据集按照以下方式组织:
|
||||
而 **sharegpt** 格式的数据集按照以下方式组织:
|
||||
|
||||
```json
|
||||
[
|
||||
|
@ -112,10 +145,12 @@
|
|||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"system": "system",
|
||||
|
@ -132,4 +167,46 @@
|
|||
|
||||
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
|
||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||
我们同样支持 **openai** 格式的数据集:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "系统提示词(选填)"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "用户指令"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "模型回答"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant",
|
||||
"system_tag": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
预训练数据集和偏好数据集**尚不支持** sharegpt 格式。
|
||||
|
|
|
@ -19,7 +19,7 @@ import pandas as pd
|
|||
|
||||
_CITATION = """\
|
||||
@article{huang2023ceval,
|
||||
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
|
||||
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
|
||||
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
|
||||
journal={arXiv preprint arXiv:2305.08322},
|
||||
year={2023}
|
||||
|
@ -133,25 +133,19 @@ class Ceval(datasets.GeneratorBasedBuilder):
|
|||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "test", f"{task_name}_test.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "test", f"{task_name}_test.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "val", f"{task_name}_val.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "val", f"{task_name}_val.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "dev", f"{task_name}_dev.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -37,73 +37,73 @@ _LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 Internatio
|
|||
_URL = "cmmlu.zip"
|
||||
|
||||
task_list = [
|
||||
'agronomy',
|
||||
'anatomy',
|
||||
'ancient_chinese',
|
||||
'arts',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'chinese_civil_service_exam',
|
||||
'chinese_driving_rule',
|
||||
'chinese_food_culture',
|
||||
'chinese_foreign_policy',
|
||||
'chinese_history',
|
||||
'chinese_literature',
|
||||
'chinese_teacher_qualification',
|
||||
'clinical_knowledge',
|
||||
'college_actuarial_science',
|
||||
'college_education',
|
||||
'college_engineering_hydrology',
|
||||
'college_law',
|
||||
'college_mathematics',
|
||||
'college_medical_statistics',
|
||||
'college_medicine',
|
||||
'computer_science',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'construction_project_management',
|
||||
'economics',
|
||||
'education',
|
||||
'electrical_engineering',
|
||||
'elementary_chinese',
|
||||
'elementary_commonsense',
|
||||
'elementary_information_and_technology',
|
||||
'elementary_mathematics',
|
||||
'ethnology',
|
||||
'food_science',
|
||||
'genetics',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_geography',
|
||||
'high_school_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_politics',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'journalism',
|
||||
'jurisprudence',
|
||||
'legal_and_moral_basis',
|
||||
'logical',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'marxist_theory',
|
||||
'modern_chinese',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_study',
|
||||
'sociology',
|
||||
'sports_science',
|
||||
'traditional_chinese_medicine',
|
||||
'virology',
|
||||
'world_history',
|
||||
'world_religions',
|
||||
"agronomy",
|
||||
"anatomy",
|
||||
"ancient_chinese",
|
||||
"arts",
|
||||
"astronomy",
|
||||
"business_ethics",
|
||||
"chinese_civil_service_exam",
|
||||
"chinese_driving_rule",
|
||||
"chinese_food_culture",
|
||||
"chinese_foreign_policy",
|
||||
"chinese_history",
|
||||
"chinese_literature",
|
||||
"chinese_teacher_qualification",
|
||||
"clinical_knowledge",
|
||||
"college_actuarial_science",
|
||||
"college_education",
|
||||
"college_engineering_hydrology",
|
||||
"college_law",
|
||||
"college_mathematics",
|
||||
"college_medical_statistics",
|
||||
"college_medicine",
|
||||
"computer_science",
|
||||
"computer_security",
|
||||
"conceptual_physics",
|
||||
"construction_project_management",
|
||||
"economics",
|
||||
"education",
|
||||
"electrical_engineering",
|
||||
"elementary_chinese",
|
||||
"elementary_commonsense",
|
||||
"elementary_information_and_technology",
|
||||
"elementary_mathematics",
|
||||
"ethnology",
|
||||
"food_science",
|
||||
"genetics",
|
||||
"global_facts",
|
||||
"high_school_biology",
|
||||
"high_school_chemistry",
|
||||
"high_school_geography",
|
||||
"high_school_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_politics",
|
||||
"human_sexuality",
|
||||
"international_law",
|
||||
"journalism",
|
||||
"jurisprudence",
|
||||
"legal_and_moral_basis",
|
||||
"logical",
|
||||
"machine_learning",
|
||||
"management",
|
||||
"marketing",
|
||||
"marxist_theory",
|
||||
"modern_chinese",
|
||||
"nutrition",
|
||||
"philosophy",
|
||||
"professional_accounting",
|
||||
"professional_law",
|
||||
"professional_medicine",
|
||||
"professional_psychology",
|
||||
"public_relations",
|
||||
"security_study",
|
||||
"sociology",
|
||||
"sports_science",
|
||||
"traditional_chinese_medicine",
|
||||
"virology",
|
||||
"world_history",
|
||||
"world_religions",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -136,25 +136,19 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
|||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "test", f"{task_name}_test.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "data", "test", f"{task_name}_test.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "val", f"{task_name}_val.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "data", "val", f"{task_name}_val.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "data", "dev", f"{task_name}_dev.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -10,7 +10,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
|||
--finetuning_type full \
|
||||
--use_badam \
|
||||
--badam_switch_mode descending \
|
||||
--badam_switch_block_every 50 \
|
||||
--badam_switch_interval 50 \
|
||||
--badam_verbose 2 \
|
||||
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
|
||||
--overwrite_cache \
|
||||
|
|
|
@ -7,7 +7,7 @@ pip install "bitsandbytes>=0.43.0"
|
|||
|
||||
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||
--config_file ../../accelerate/fsdp_config.yaml \
|
||||
../../../src/train_bash.py \
|
||||
../../../src/train.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-70b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path ../../../models/llama2-7b-pro \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -6,7 +6,7 @@ python -m torch.distributed.run \
|
|||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/single_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--stage sft \
|
||||
--do_predict \
|
||||
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||
deepspeed --num_gpus 4 ../../src/train.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template fewshot \
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
# add `--visual_inputs True` to load MLLM
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#!/bin/bash
|
||||
# ZeRO-3 enables weight sharding on multiple GPUs
|
||||
|
||||
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||
deepspeed --num_gpus 4 ../../src/train.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/master_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/single_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage dpo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage orpo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage ppo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_predict \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#!/bin/bash
|
||||
# use `--tokenized_path` in training script to load data
|
||||
|
||||
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES= llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage pt \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage rm \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path llava-hf/llava-1.5-7b-hf \
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
# DO NOT use quantized model or quantization_bit when merging lora weights
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
# NEED TO run `merge.sh` before using this script
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
|
||||
--model_name_or_path ../../models/llama2-7b-sft \
|
||||
--template default \
|
||||
--export_dir ../../models/llama2-7b-sft-int4 \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path TheBloke/Llama-2-7B-AWQ \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path TheBloke/Llama-2-7B-GPTQ \
|
||||
|
|
|
@ -16,3 +16,4 @@ sse-starlette
|
|||
matplotlib
|
||||
fire
|
||||
packaging
|
||||
pyyaml
|
||||
|
|
|
@ -3,24 +3,22 @@
|
|||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.chat import ChatModel
|
||||
|
||||
|
||||
def calculate_flops(
|
||||
model_name_or_path: str,
|
||||
batch_size: Optional[int] = 1,
|
||||
seq_length: Optional[int] = 256,
|
||||
flash_attn: Optional[bool] = False,
|
||||
batch_size: int = 1,
|
||||
seq_length: int = 256,
|
||||
flash_attn: str = "auto",
|
||||
):
|
||||
with get_accelerator().device(0):
|
||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
|
||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
||||
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
import fire
|
||||
import torch
|
||||
|
@ -25,12 +25,12 @@ BASE_BS = 4_000_000 # from llama paper
|
|||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
stage: Optional[str] = "sft",
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
|
||||
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
|
||||
stage: Literal["pt", "sft"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||
is_mistral: bool = False, # mistral model uses a smaller learning rate,
|
||||
):
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
|
@ -54,9 +54,7 @@ def calculate_lr(
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
|
||||
)
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
valid_tokens, total_tokens = 0, 0
|
||||
for batch in tqdm(dataloader):
|
||||
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# coding=utf-8
|
||||
# Calculates the ppl on the dataset of the pre-trained models.
|
||||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional, Sequence
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_model, load_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
train_on_prompt: bool = False
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
chosen_features = []
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
|
||||
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
|
||||
attention_mask = [1] * (prompt_len + answer_len)
|
||||
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
|
||||
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
|
||||
|
||||
return super().__call__(chosen_features)
|
||||
|
||||
|
||||
def cal_ppl(
|
||||
model_name_or_path: str,
|
||||
save_name: str,
|
||||
batch_size: int = 4,
|
||||
stage: Literal["pt", "sft", "rm"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024,
|
||||
max_samples: Optional[int] = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage=stage,
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template=template,
|
||||
cutoff_len=cutoff_len,
|
||||
max_samples=max_samples,
|
||||
train_on_prompt=train_on_prompt,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
elif stage == "rm":
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
total_ppl = 0
|
||||
perplexities = []
|
||||
batch: Dict[str, "torch.Tensor"]
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader):
|
||||
batch = batch.to(model.device)
|
||||
outputs = model(**batch)
|
||||
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
||||
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
|
||||
loss_mask = shift_labels != IGNORE_INDEX
|
||||
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
|
||||
flatten_labels = shift_labels.contiguous().view(-1)
|
||||
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
|
||||
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
|
||||
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
total_ppl += sentence_logps.exp().sum().item()
|
||||
perplexities.extend(sentence_logps.exp().tolist())
|
||||
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(perplexities, f, indent=2)
|
||||
|
||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
||||
print("Perplexities have been saved at {}.".format(save_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(cal_ppl)
|
|
@ -3,7 +3,6 @@
|
|||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
|
@ -15,10 +14,10 @@ from llmtuner.model import load_tokenizer
|
|||
|
||||
def length_cdf(
|
||||
model_name_or_path: str,
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
interval: Optional[int] = 1000,
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
interval: int = 1000,
|
||||
):
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
|
|
1
setup.py
1
setup.py
|
@ -52,6 +52,7 @@ def main():
|
|||
python_requires=">=3.8.0",
|
||||
install_requires=get_requires(),
|
||||
extras_require=extra_require,
|
||||
entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner import ChatModel, create_app
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,49 +0,0 @@
|
|||
from llmtuner import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
|
||||
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("\nUser: ")
|
||||
except UnicodeDecodeError:
|
||||
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if query.strip() == "exit":
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,9 +0,0 @@
|
|||
from llmtuner import Evaluator
|
||||
|
||||
|
||||
def main():
|
||||
Evaluator().eval()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,9 +0,0 @@
|
|||
from llmtuner import export_model
|
||||
|
||||
|
||||
def main():
|
||||
export_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,11 +1,3 @@
|
|||
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
||||
|
||||
from .api import create_app
|
||||
from .chat import ChatModel
|
||||
from .eval import Evaluator
|
||||
from .train import export_model, run_exp
|
||||
from .webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.7.0"
|
||||
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
|
||||
__version__ = "0.7.1.dev0"
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from .app import create_app
|
||||
|
||||
|
||||
__all__ = ["create_app"]
|
|
@ -1,36 +1,29 @@
|
|||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
from .chat import (
|
||||
create_chat_completion_response,
|
||||
create_score_evaluation_response,
|
||||
create_stream_chat_completion_response,
|
||||
)
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
Role,
|
||||
ScoreEvaluationRequest,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
|
@ -47,23 +40,8 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
|
|||
torch_gc()
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
|
@ -71,162 +49,58 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
api_key = os.environ.get("API_KEY", None)
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
role_mapping = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
if api_key and (auth is None or auth.credentials != api_key):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
@app.get(
|
||||
"/v1/models",
|
||||
response_model=ModelList,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
@app.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=ChatCompletionResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if not chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
|
||||
else:
|
||||
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
if request.stream:
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
generate = stream_chat_completion(input_messages, system, tools, request)
|
||||
generate = create_stream_chat_completion_response(request, chat_model)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
else:
|
||||
return await create_chat_completion_response(request, chat_model)
|
||||
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
stop=request.stop
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
response_message = ChatCompletionMessage(
|
||||
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
|
||||
)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
|
||||
)
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def stream_chat_completion(
|
||||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
async for new_token in chat_model.astream_chat(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
stop=request.stop
|
||||
):
|
||||
if len(new_token) == 0:
|
||||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||
@app.post(
|
||||
"/v1/score/evaluation",
|
||||
response_model=ScoreEvaluationResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||
if chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
return await create_score_evaluation_response(request, chat_model)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def run_api() -> None:
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.packages import is_fastapi_availble
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionStreamResponseChoice,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
Role,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..chat import ChatModel
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
ROLE_MAPPING = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
|
||||
|
||||
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
return input_messages, system, tools
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
completion_id: str,
|
||||
model: str,
|
||||
delta: "ChatCompletionMessage",
|
||||
index: Optional[int] = 0,
|
||||
finish_reason: Optional["Finish"] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
|
||||
return jsonify(chunk)
|
||||
|
||||
|
||||
async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
|
||||
|
||||
|
||||
async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
|
||||
)
|
||||
async for new_token in chat_model.astream_chat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
):
|
||||
if len(new_token) != 0:
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
||||
)
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
async def create_score_evaluation_response(
|
||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||
) -> "ScoreEvaluationResponse":
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
|
@ -0,0 +1,20 @@
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
|
@ -51,7 +51,7 @@ class FunctionAvailable(BaseModel):
|
|||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
id: Literal["call_default"] = "call_default"
|
||||
id: str
|
||||
type: Literal["function"] = "function"
|
||||
function: Function
|
||||
|
||||
|
@ -87,7 +87,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
|||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
|
@ -100,7 +100,7 @@ class ChatCompletionResponseUsage(BaseModel):
|
|||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
|
@ -109,11 +109,11 @@ class ChatCompletionResponse(BaseModel):
|
|||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
choices: List[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
|
@ -123,7 +123,7 @@ class ScoreEvaluationRequest(BaseModel):
|
|||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||
id: str
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import get_infer_args
|
||||
from .hf_engine import HuggingfaceEngine
|
||||
from .vllm_engine import VllmEngine
|
||||
|
@ -95,3 +96,45 @@ class ChatModel:
|
|||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
chat_model = ChatModel()
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("\nUser: ")
|
||||
except UnicodeDecodeError:
|
||||
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if query.strip() == "exit":
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from . import __version__
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
|
||||
USAGE = """
|
||||
Usage:
|
||||
llamafactory-cli api -h: launch an API server
|
||||
llamafactory-cli chat -h: launch a chat interface in CLI
|
||||
llamafactory-cli eval -h: do evaluation
|
||||
llamafactory-cli export -h: merge LoRA adapters and export model
|
||||
llamafactory-cli train -h: do training
|
||||
llamafactory-cli webchat -h: launch a chat interface in Web UI
|
||||
llamafactory-cli webui: launch LlamaBoard
|
||||
llamafactory-cli version: show version info
|
||||
"""
|
||||
|
||||
|
||||
@unique
|
||||
class Command(str, Enum):
|
||||
API = "api"
|
||||
CHAT = "chat"
|
||||
EVAL = "eval"
|
||||
EXPORT = "export"
|
||||
TRAIN = "train"
|
||||
WEBDEMO = "webchat"
|
||||
WEBUI = "webui"
|
||||
VERSION = "version"
|
||||
HELP = "help"
|
||||
|
||||
|
||||
def main():
|
||||
command = sys.argv.pop(1)
|
||||
if command == Command.API:
|
||||
run_api()
|
||||
elif command == Command.CHAT:
|
||||
run_chat()
|
||||
elif command == Command.EVAL:
|
||||
run_eval()
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
run_web_demo()
|
||||
elif command == Command.WEBUI:
|
||||
run_web_ui()
|
||||
elif command == Command.VERSION:
|
||||
print("Welcome to LLaMA Factory, version {}".format(__version__))
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError("Unknown command: {}".format(command))
|
|
@ -1,4 +0,0 @@
|
|||
from .evaluator import Evaluator
|
||||
|
||||
|
||||
__all__ = ["Evaluator"]
|
|
@ -118,6 +118,5 @@ class Evaluator:
|
|||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
def run_eval() -> None:
|
||||
Evaluator().eval()
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import transformers
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||
|
||||
from .constants import LOG_FILE_NAME
|
||||
from .logging import get_logger
|
||||
from .constants import TRAINER_LOG
|
||||
from .logging import LoggerHandler, get_logger
|
||||
from .misc import fix_valuehead_checkpoint
|
||||
|
||||
|
||||
|
@ -33,57 +38,92 @@ class FixValueHeadModelCallback(TrainerCallback):
|
|||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.in_training = False
|
||||
self.start_time = time.time()
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
""" Progress """
|
||||
self.start_time = 0
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
""" Status """
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
""" Web UI """
|
||||
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = LoggerHandler(output_dir)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
def timing(self):
|
||||
def _set_abort(self, signum, frame) -> None:
|
||||
self.aborted = True
|
||||
|
||||
def _reset(self, max_steps: int = 0) -> None:
|
||||
self.start_time = time.time()
|
||||
self.cur_steps = 0
|
||||
self.max_steps = max_steps
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
|
||||
def _timing(self, cur_steps: int) -> None:
|
||||
cur_time = time.time()
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
|
||||
self.cur_steps = cur_steps
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def _create_thread_pool(self, output_dir: str) -> None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _close_thread_pool(self) -> None:
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
|
||||
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of the initialization of the `Trainer`.
|
||||
"""
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
):
|
||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.in_training = True
|
||||
self.start_time = time.time()
|
||||
self.max_steps = state.max_steps
|
||||
|
||||
if args.save_on_each_node:
|
||||
if not state.is_local_process_zero:
|
||||
return
|
||||
else:
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||
if args.should_save:
|
||||
self.do_train = True
|
||||
self._reset(max_steps=state.max_steps)
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.in_training = False
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
|
@ -91,42 +131,30 @@ class LogCallback(TrainerCallback):
|
|||
r"""
|
||||
Event called at the end of a training step.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.cur_steps = state.global_step
|
||||
self.timing()
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
if state.is_local_process_zero and not self.in_training:
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_predict(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
|
||||
):
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
if state.is_local_process_zero and not self.in_training:
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if args.save_on_each_node:
|
||||
if not state.is_local_process_zero:
|
||||
return
|
||||
else:
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
self._timing(cur_steps=state.global_step)
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
|
@ -141,16 +169,16 @@ class LogCallback(TrainerCallback):
|
|||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
if self.runner is not None:
|
||||
logs = {k: v for k, v in logs.items() if v is not None}
|
||||
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
||||
logger.info(
|
||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||
logs["loss"], logs["learning_rate"], logs["epoch"]
|
||||
)
|
||||
)
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
|
@ -158,9 +186,28 @@ class LogCallback(TrainerCallback):
|
|||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
if self.do_train:
|
||||
return
|
||||
|
||||
if self.aborted:
|
||||
sys.exit(0)
|
||||
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
|
||||
if has_length(eval_dataloader):
|
||||
if self.max_steps == 0:
|
||||
self.max_steps = len(eval_dataloader)
|
||||
self.cur_steps += 1
|
||||
self.timing()
|
||||
self._reset(max_steps=len(eval_dataloader))
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
self._timing(cur_steps=self.cur_steps + 1)
|
||||
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
|
|
@ -24,8 +24,6 @@ IGNORE_INDEX = -100
|
|||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
MLLM_LIST = ["LLaVA1.5"]
|
||||
|
@ -34,10 +32,16 @@ MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral
|
|||
|
||||
PEFT_METHODS = ["lora"]
|
||||
|
||||
RUNNING_LOG = "running_log.txt"
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
TRAINER_CONFIG = "trainer_config.yaml"
|
||||
|
||||
TRAINER_LOG = "trainer_log.jsonl"
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .constants import RUNNING_LOG
|
||||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
|
@ -7,19 +11,35 @@ class LoggerHandler(logging.Handler):
|
|||
Logger handler used in Web UI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
super().__init__()
|
||||
self.log = ""
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
self.setLevel(logging.INFO)
|
||||
self.setFormatter(formatter)
|
||||
|
||||
def reset(self):
|
||||
self.log = ""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
||||
if os.path.exists(self.running_log):
|
||||
os.remove(self.running_log)
|
||||
|
||||
def emit(self, record):
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _write_log(self, log_entry: str) -> None:
|
||||
with open(self.running_log, "a", encoding="utf-8") as f:
|
||||
f.write(log_entry + "\n\n")
|
||||
|
||||
def emit(self, record) -> None:
|
||||
if record.name == "httpx":
|
||||
return
|
||||
|
||||
log_entry = self.format(record)
|
||||
self.log += log_entry
|
||||
self.log += "\n\n"
|
||||
self.thread_pool.submit(self._write_log, log_entry)
|
||||
|
||||
def close(self) -> None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
return super().close()
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
|
@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
|
|||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
@ -21,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
|||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
smoothed = []
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
|
@ -30,7 +31,27 @@ def smooth(scalars: List[float]) -> List[float]:
|
|||
return smoothed
|
||||
|
||||
|
||||
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
steps, losses = [], []
|
||||
for log in trainer_log:
|
||||
if log.get("loss", None):
|
||||
steps.append(log["current_steps"])
|
||||
losses.append(log["loss"])
|
||||
|
||||
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
||||
ax.legend()
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
|
|
@ -221,16 +221,18 @@ class BAdamArgument:
|
|||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_block_every: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
},
|
||||
)
|
||||
badam_update_ratio: float = field(
|
||||
default=0.0,
|
||||
default=0.05,
|
||||
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
|
||||
)
|
||||
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
||||
|
@ -308,6 +310,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||
if self.use_galore and self.finetuning_type == "lora":
|
||||
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
|
|||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import TRAINER_CONFIG
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from .data_args import DataArguments
|
||||
|
@ -251,7 +252,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
files = os.listdir(training_args.output_dir)
|
||||
if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG):
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from .tuner import export_model, run_exp
|
||||
|
||||
|
||||
__all__ = ["export_model", "run_exp"]
|
|
@ -165,13 +165,13 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
|
|
@ -113,15 +113,15 @@ class CustomORPOTrainer(DPOTrainer):
|
|||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
|
||||
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
|
||||
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
|
||||
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
|
||||
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
|
||||
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().mean().cpu()
|
||||
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().mean().cpu()
|
||||
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().mean().cpu()
|
||||
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().mean().cpu()
|
||||
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
|
||||
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().mean().cpu()
|
||||
|
||||
return batch_loss, metrics
|
||||
|
|
|
@ -23,9 +23,9 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
callbacks.append(LogCallback(training_args.output_dir))
|
||||
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
|
@ -43,7 +43,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
|
|||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
if model_args.export_dir is None:
|
||||
|
@ -88,7 +88,3 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
|||
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||
except Exception:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_exp()
|
||||
|
|
|
@ -317,14 +317,14 @@ def _create_badam_optimizer(
|
|||
base_optimizer=base_optimizer,
|
||||
named_parameters_list=list(model.named_parameters()),
|
||||
block_prefix_list=None,
|
||||
switch_block_every=finetuning_args.badam_switch_block_every,
|
||||
switch_block_every=finetuning_args.badam_switch_interval,
|
||||
start_block=finetuning_args.badam_start_block,
|
||||
switch_mode=finetuning_args.badam_switch_mode,
|
||||
verbose=finetuning_args.badam_verbose,
|
||||
)
|
||||
logger.info(
|
||||
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
|
||||
f"switch block every {finetuning_args.badam_switch_block_every} steps, "
|
||||
f"switch block every {finetuning_args.badam_switch_interval} steps, "
|
||||
f"default start block is {finetuning_args.badam_start_block}"
|
||||
)
|
||||
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from .interface import create_ui, create_web_demo
|
||||
|
||||
|
||||
__all__ = ["create_ui", "create_web_demo"]
|
|
@ -4,6 +4,7 @@ from collections import defaultdict
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from yaml import safe_dump, safe_load
|
||||
|
||||
from ..extras.constants import (
|
||||
DATA_CONFIG,
|
||||
|
@ -16,6 +17,7 @@ from ..extras.constants import (
|
|||
TRAINING_STAGES,
|
||||
DownloadSource,
|
||||
)
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import use_modelscope
|
||||
from ..extras.packages import is_gradio_available
|
||||
|
||||
|
@ -24,12 +26,15 @@ if is_gradio_available():
|
|||
import gradio as gr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_CONFIG_DIR = "config"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
USER_CONFIG = "user_config.yaml"
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
|
@ -47,7 +52,7 @@ def get_save_path(config_path: str) -> os.PathLike:
|
|||
def load_config() -> Dict[str, Any]:
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
|
@ -60,13 +65,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
|||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
safe_dump(user_config, f)
|
||||
|
||||
|
||||
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
@ -74,7 +79,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
|||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
|
||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
|
||||
json.dump(config_dict, f, indent=2, ensure_ascii=False)
|
||||
safe_dump(config_dict, f)
|
||||
|
||||
return str(get_save_path(config_path))
|
||||
|
||||
|
@ -127,11 +132,15 @@ def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
|||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
if dataset_dir == "ONLINE":
|
||||
logger.info("dataset_dir is ONLINE, using online dataset.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as err:
|
||||
print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||
return {}
|
||||
|
||||
|
||||
|
|
|
@ -36,9 +36,9 @@ def create_chat_box(
|
|||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
|
||||
top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
|
||||
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
|
||||
|
|
|
@ -21,25 +21,25 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
|
||||
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
||||
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
|
||||
|
||||
with gr.Row():
|
||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
|
||||
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
||||
|
@ -52,19 +52,19 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
output_elems = [output_box, process_bar]
|
||||
output_elems = [output_box, progress_bar]
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict, Generator, List
|
|||
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train import export_model
|
||||
from ...train.tuner import export_model
|
||||
from ..common import get_save_dir
|
||||
from ..locales import ALERTS
|
||||
|
||||
|
@ -85,7 +85,7 @@ def save_model(
|
|||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
|
||||
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
|
||||
|
|
|
@ -27,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
|
@ -52,10 +52,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
|
||||
batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
|
||||
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
|
||||
val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
|
||||
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
|
||||
|
||||
input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
|
||||
|
@ -71,10 +71,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as extra_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
|
||||
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
|
||||
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
|
||||
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
|
||||
optim = gr.Textbox(value="adamw_torch")
|
||||
|
||||
with gr.Row():
|
||||
|
@ -124,7 +124,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
|
||||
num_layer_trainable = gr.Slider(minimum=1, maximum=128, value=2, step=1)
|
||||
name_module_trainable = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({num_layer_trainable, name_module_trainable})
|
||||
|
@ -136,10 +136,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
|
||||
lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
|
||||
lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
|
||||
lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
|
@ -180,9 +180,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
|
||||
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
||||
|
@ -193,9 +193,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
with gr.Accordion(open=False) as galore_tab:
|
||||
with gr.Row():
|
||||
use_galore = gr.Checkbox()
|
||||
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
|
||||
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
|
||||
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
|
||||
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
|
||||
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
|
||||
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
|
||||
galore_target = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
|
||||
|
@ -210,6 +210,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as badam_tab:
|
||||
with gr.Row():
|
||||
use_badam = gr.Checkbox()
|
||||
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
|
||||
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
|
||||
badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
|
||||
badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
|
||||
|
||||
input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
badam_tab=badam_tab,
|
||||
use_badam=use_badam,
|
||||
badam_mode=badam_mode,
|
||||
badam_switch_mode=badam_switch_mode,
|
||||
badam_switch_interval=badam_switch_interval,
|
||||
badam_update_ratio=badam_update_ratio,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
arg_save_btn = gr.Button()
|
||||
|
@ -225,7 +245,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
@ -243,14 +263,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer,
|
||||
)
|
||||
)
|
||||
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, process_bar, loss_viewer]
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
|
|
|
@ -41,7 +41,7 @@ class Engine:
|
|||
init_dict["train.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||
init_dict["infer.image_box"] = {"visible": False}
|
||||
|
||||
|
@ -51,7 +51,7 @@ class Engine:
|
|||
|
||||
yield self._update_component(init_dict)
|
||||
|
||||
if self.runner.alive and not self.demo_mode and not self.pure_chat:
|
||||
if self.runner.running and not self.demo_mode and not self.pure_chat:
|
||||
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._update_component({"train.resume_btn": {"value": True}})
|
||||
|
|
|
@ -68,5 +68,9 @@ def create_web_demo() -> gr.Blocks:
|
|||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
def run_web_ui() -> None:
|
||||
create_ui().queue().launch()
|
||||
|
||||
|
||||
def run_web_demo() -> None:
|
||||
create_web_demo().queue().launch()
|
||||
|
|
|
@ -891,6 +891,87 @@ LOCALES = {
|
|||
"info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"badam_tab": {
|
||||
"en": {
|
||||
"label": "BAdam configurations",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Конфигурации BAdam",
|
||||
},
|
||||
"zh": {
|
||||
"label": "BAdam 参数设置",
|
||||
},
|
||||
},
|
||||
"use_badam": {
|
||||
"en": {
|
||||
"label": "Use BAdam",
|
||||
"info": "Enable the BAdam optimizer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать BAdam",
|
||||
"info": "Включите оптимизатор BAdam.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 BAdam",
|
||||
"info": "使用 BAdam 优化器。",
|
||||
},
|
||||
},
|
||||
"badam_mode": {
|
||||
"en": {
|
||||
"label": "BAdam mode",
|
||||
"info": "Whether to use layer-wise or ratio-wise BAdam optimizer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Режим BAdam",
|
||||
"info": "Использовать ли оптимизатор BAdam с послоевой или пропорциональной настройкой.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "BAdam 模式",
|
||||
"info": "使用 layer-wise 或 ratio-wise BAdam 优化器。",
|
||||
},
|
||||
},
|
||||
"badam_switch_mode": {
|
||||
"en": {
|
||||
"label": "Switch mode",
|
||||
"info": "The strategy of picking block to update for layer-wise BAdam.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Режим переключения",
|
||||
"info": "Стратегия выбора блока для обновления для послойного BAdam.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "切换策略",
|
||||
"info": "Layer-wise BAdam 优化器的块切换策略。",
|
||||
},
|
||||
},
|
||||
"badam_switch_interval": {
|
||||
"en": {
|
||||
"label": "Switch interval",
|
||||
"info": "Number of steps to update the block for layer-wise BAdam.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Интервал переключения",
|
||||
"info": "количество шагов для обновления блока для пошагового BAdam.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "切换频率",
|
||||
"info": "Layer-wise BAdam 优化器的块切换频率。",
|
||||
},
|
||||
},
|
||||
"badam_update_ratio": {
|
||||
"en": {
|
||||
"label": "Update ratio",
|
||||
"info": "The ratio of the update for ratio-wise BAdam.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Коэффициент обновления",
|
||||
"info": "Коэффициент обновления для BAdam с учётом соотношений.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "Block 更新比例",
|
||||
"info": "Ratio-wise BAdam 优化器的更新比例。",
|
||||
},
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview command",
|
||||
|
@ -1368,7 +1449,7 @@ ALERTS = {
|
|||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"ru": "Прервано, ожидание завершения...",
|
||||
"zh": "训练中断,正在等待线程结束……",
|
||||
"zh": "训练中断,正在等待进程结束……",
|
||||
},
|
||||
"info_aborted": {
|
||||
"en": "Ready.",
|
||||
|
|
|
@ -1,22 +1,19 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
import transformers
|
||||
import psutil
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.logging import LoggerHandler
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from ..train import run_exp
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
|
@ -34,24 +31,18 @@ class Runner:
|
|||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.trainer: Optional["Popen"] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return self.thread is not None
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
if self.trainer is not None:
|
||||
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
|
||||
os.kill(children.pid, signal.SIGABRT)
|
||||
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
|
@ -85,13 +76,11 @@ class Runner:
|
|||
if not from_preview and not is_torch_cuda_available():
|
||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||
self.thread = None
|
||||
self.trainer = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.running_data = None
|
||||
|
@ -147,12 +136,12 @@ class Runner:
|
|||
shift_attn=get("train.shift_attn"),
|
||||
report_to="all" if get("train.report_to") else "none",
|
||||
use_galore=get("train.use_galore"),
|
||||
use_badam=get("train.use_badam"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||
fp16=(get("train.compute_type") == "fp16"),
|
||||
bf16=(get("train.compute_type") == "bf16"),
|
||||
pure_bf16=(get("train.compute_type") == "pure_bf16"),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||
|
@ -198,6 +187,12 @@ class Runner:
|
|||
args["galore_scale"] = get("train.galore_scale")
|
||||
args["galore_target"] = get("train.galore_target")
|
||||
|
||||
if args["use_badam"]:
|
||||
args["badam_mode"] = get("train.badam_mode")
|
||||
args["badam_switch_mode"] = get("train.badam_switch_mode")
|
||||
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
||||
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
||||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
|
@ -237,7 +232,6 @@ class Runner:
|
|||
temperature=get("eval.temperature"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if get("eval.predict"):
|
||||
args["do_predict"] = True
|
||||
|
@ -263,11 +257,12 @@ class Runner:
|
|||
gr.Warning(error)
|
||||
yield {output_box: error}
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
env = deepcopy(os.environ)
|
||||
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
env["LLAMABOARD_ENABLED"] = "1"
|
||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data):
|
||||
|
@ -283,10 +278,10 @@ class Runner:
|
|||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self):
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
lang = get("top.lang")
|
||||
model_name = get("top.model_name")
|
||||
finetuning_type = get("top.finetuning_type")
|
||||
|
@ -294,28 +289,31 @@ class Runner:
|
|||
output_path = get_save_dir(model_name, finetuning_type, output_dir)
|
||||
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
||||
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
|
||||
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
||||
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
||||
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
while self.trainer is not None:
|
||||
if self.aborted:
|
||||
yield {
|
||||
output_box: ALERTS["info_aborting"][lang],
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
else:
|
||||
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
|
||||
return_dict = {
|
||||
output_box: self.logger_handler.log,
|
||||
process_bar: update_process_bar(self.trainer_callback),
|
||||
output_box: running_log,
|
||||
progress_bar: running_progress,
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
if running_loss is not None:
|
||||
return_dict[loss_viewer] = running_loss
|
||||
|
||||
yield return_dict
|
||||
|
||||
time.sleep(2)
|
||||
try:
|
||||
self.trainer.wait(2)
|
||||
self.trainer = None
|
||||
except TimeoutExpired:
|
||||
continue
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
||||
|
@ -330,16 +328,11 @@ class Runner:
|
|||
|
||||
return_dict = {
|
||||
output_box: self._finalize(lang, finish_info),
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data):
|
||||
def save_args(self, data: dict):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from yaml import safe_dump
|
||||
|
||||
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
|
||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import smooth
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
|
@ -12,30 +15,6 @@ if is_gradio_available():
|
|||
import gradio as gr
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..extras.callbacks import LogCallback
|
||||
|
||||
|
||||
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
|
||||
if not callback.max_steps:
|
||||
return gr.Slider(visible=False)
|
||||
|
||||
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||
)
|
||||
return gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type != "lora":
|
||||
return gr.Dropdown(value="none", interactive=False)
|
||||
|
@ -57,14 +36,18 @@ def check_json_schema(text: str, lang: str) -> None:
|
|||
gr.Warning(ALERTS["err_json_schema"][lang])
|
||||
|
||||
|
||||
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
no_skip_keys = ["packing"]
|
||||
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
|
||||
for k, v in args.items():
|
||||
if v is not None and v is not False and v != "":
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(current_devices)]
|
||||
for k, v in clean_cmd(args).items():
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
|
||||
cmd_text = "\\\n".join(cmd_lines)
|
||||
cmd_text = "```bash\n{}\n```".format(cmd_text)
|
||||
return cmd_text
|
||||
|
@ -76,29 +59,49 @@ def get_eval_results(path: os.PathLike) -> str:
|
|||
return "```json\n{}\n```\n".format(result)
|
||||
|
||||
|
||||
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
|
||||
log_file = os.path.join(output_path, "trainer_log.jsonl")
|
||||
if not os.path.isfile(log_file) or not is_matplotlib_available():
|
||||
return
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
steps, losses = [], []
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
log_info: Dict[str, Any] = json.loads(line)
|
||||
if log_info.get("loss", None):
|
||||
steps.append(log_info["current_steps"])
|
||||
losses.append(log_info["loss"])
|
||||
|
||||
if len(losses) == 0:
|
||||
return
|
||||
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
|
||||
running_log = ""
|
||||
running_progress = gr.Slider(visible=False)
|
||||
running_loss = None
|
||||
|
||||
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
||||
ax.legend()
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||
if os.path.isfile(running_log_path):
|
||||
with open(running_log_path, "r", encoding="utf-8") as f:
|
||||
running_log = f.read()
|
||||
|
||||
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
|
||||
if os.path.isfile(trainer_log_path):
|
||||
trainer_log: List[Dict[str, Any]] = []
|
||||
with open(trainer_log_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
trainer_log.append(json.loads(line))
|
||||
|
||||
if len(trainer_log) != 0:
|
||||
latest_log = trainer_log[-1]
|
||||
percentage = latest_log["percentage"]
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
latest_log["current_steps"],
|
||||
latest_log["total_steps"],
|
||||
latest_log["elapsed_time"],
|
||||
latest_log["remaining_time"],
|
||||
)
|
||||
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
if do_train and is_matplotlib_available():
|
||||
running_loss = gr.Plot(gen_loss_plot(trainer_log))
|
||||
|
||||
return running_log, running_progress, running_loss
|
||||
|
||||
|
||||
def save_cmd(args: Dict[str, Any]) -> str:
|
||||
output_dir = args["output_dir"]
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
|
||||
safe_dump(clean_cmd(args), f)
|
||||
|
||||
return os.path.join(output_dir, TRAINER_CONFIG)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from llmtuner import run_exp
|
||||
from llmtuner.train.tuner import run_exp
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -7,7 +7,7 @@ def main():
|
|||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
run_exp()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
|
@ -1,9 +0,0 @@
|
|||
from llmtuner import create_ui
|
||||
|
||||
|
||||
def main():
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,9 +0,0 @@
|
|||
from llmtuner import create_web_demo
|
||||
|
||||
|
||||
def main():
|
||||
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue