Merge branch 'main' into feat/support_ms

commit 6382efec52

README.md (63)
@@ -1,4 +1,4 @@
-# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
+![# LLaMA Factory](assets/logo.png)
 
 [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)

@@ -44,16 +44,24 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ![benchmark](assets/benchmark.svg)
 
+<details><summary>Definitions</summary>
+
 - **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
 - **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
 - **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
 - We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
 
+</details>
+
 ## Changelog
 
-[23/12/01] We supported **[ModelScope Hub](https://www.modelscope.cn/models)** to accelerate model downloading. Add environment variable `USE_MODELSCOPE_HUB=1` to your command line, then you can use the model-id of ModelScope Hub.
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
 
-[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
+[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.
 
+<details><summary>Full Changelog</summary>
+
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
 
 [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.

@@ -79,6 +87,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
 
+</details>
+
 ## Supported Models
 
 | Model | Model size | Default module | Template |

@@ -93,6 +103,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
 | [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |

@@ -198,13 +209,13 @@ huggingface-cli login
 ### Hardware Requirement
 
-| Method | Bits | 7B    | 13B   | 30B   | 65B    |
-| ------ | ---- | ----- | ----- | ----- | ------ |
-| Full   | 16   | 140GB | 240GB | 520GB | 1200GB |
-| Freeze | 16   | 20GB  | 40GB  | 120GB | 240GB  |
-| LoRA   | 16   | 16GB  | 32GB  | 80GB  | 160GB  |
-| QLoRA  | 8    | 10GB  | 16GB  | 40GB  | 80GB   |
-| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   |
+| Method | Bits | 7B    | 13B   | 30B   | 65B    | 8x7B   |
+| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
+| Full   | 16   | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| Freeze | 16   | 20GB  | 40GB  | 120GB | 240GB  | 200GB  |
+| LoRA   | 16   | 16GB  | 32GB  | 80GB  | 160GB  | 120GB  |
+| QLoRA  | 8    | 10GB  | 16GB  | 40GB  | 80GB   | 80GB   |
+| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 32GB   |
 
 ## Getting Started

@@ -231,31 +242,26 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
 pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
 ```
 
-### Use ModelScope Models
+### Use ModelScope Models (optional)
 
-If you have trouble with downloading models from HuggingFace, we have supported ModelScope Hub. To use LLaMA-Factory together with ModelScope, please add a environment variable:
+If you have trouble with downloading models from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
 
-```shell
-export USE_MODELSCOPE_HUB=1
+```bash
+export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
 ```
 
-> [!NOTE]
->
-> Please use integers only. 0 or not set for using HuggingFace hub. Other values will be treated as use ModelScope hub.
-
-Then you can use LLaMA-Factory with ModelScope model-ids:
+Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
 
-```shell
-python src/train_bash.py \
-    --model_name_or_path ZhipuAI/chatglm3-6b \
-    ... other arguments
-    # You can find all model ids in this link: https://www.modelscope.cn/models
-```
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --model_name_or_path modelscope/Llama-2-7b-ms \
+    ... # arguments (same as above)
+```
 
-Web demo also supports ModelScope, after setting the environment variable please run with this command:
+LLaMA Board also supports using the models on the ModelScope Hub.
 
-```shell
-CUDA_VISIBLE_DEVICES=0 python src/train_web.py
+```bash
+CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
 ```
 
 ### Train on a single GPU

@@ -472,6 +478,9 @@ python src/export_model.py \
     --export_dir path_to_export
 ```
 
+> [!WARNING]
+> Merging LoRA weights into a GPTQ quantized model is not supported.
+
 ### API Demo
 
 ```bash
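The `USE_MODELSCOPE_HUB` switch described in the README changes above is a plain environment variable read at startup. A minimal sketch of the convention, assuming the flag keeps the integer semantics of the removed note (the `use_modelscope` helper name is illustrative, not part of the codebase):

```python
import os

def use_modelscope() -> bool:
    # Assumption: the variable is an integer flag, as in the removed README note.
    # Unset or "0" keeps the Hugging Face hub; any other integer switches downloads
    # to the ModelScope Hub.
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))

if __name__ == "__main__":
    print("Download from ModelScope:", use_modelscope())
```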
README_zh.md (63)

@@ -1,4 +1,4 @@
-# LLaMA Factory: 轻松的大模型训练与评估
+![# LLaMA Factory](assets/logo.png)
 
 [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
 [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)

@@ -44,16 +44,24 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ![benchmark](assets/benchmark.svg)
 
+<details><summary>变量定义</summary>
+
 - **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024)
 - **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
 - **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
 - 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`。
 
+</details>
+
 ## 更新日志
 
-[23/12/01] 我们支持了 **[魔搭ModelHub](https://www.modelscope.cn/models)** 进行模型下载加速。在启动命令前环境变量中增加 `USE_MODELSCOPE_HUB=1` 即可开启。
+[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
 
-[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
+[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
 
+<details><summary>展开日志</summary>
+
+[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
 
 [23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。

@@ -79,6 +87,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 [23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
 
+</details>
+
 ## 模型
 
 | 模型名 | 模型大小 | 默认模块 | Template |

@@ -93,6 +103,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
 | [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |

@@ -198,13 +209,13 @@ huggingface-cli login
 ### 硬件依赖
 
-| 训练方法 | 精度 | 7B    | 13B   | 30B   | 65B    |
-| ------- | ---- | ----- | ----- | ----- | ------ |
-| 全参数   | 16   | 140GB | 240GB | 520GB | 1200GB |
-| 部分参数 | 16   | 20GB  | 40GB  | 120GB | 240GB  |
-| LoRA    | 16   | 16GB  | 32GB  | 80GB  | 160GB  |
-| QLoRA   | 8    | 10GB  | 16GB  | 40GB  | 80GB   |
-| QLoRA   | 4    | 6GB   | 12GB  | 24GB  | 48GB   |
+| 训练方法 | 精度 | 7B    | 13B   | 30B   | 65B    | 8x7B   |
+| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
+| 全参数   | 16   | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| 部分参数 | 16   | 20GB  | 40GB  | 120GB | 240GB  | 200GB  |
+| LoRA    | 16   | 16GB  | 32GB  | 80GB  | 160GB  | 120GB  |
+| QLoRA   | 8    | 10GB  | 16GB  | 40GB  | 80GB   | 80GB   |
+| QLoRA   | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 32GB   |
 
 ## 如何使用

@@ -231,31 +242,26 @@ pip install -r requirements.txt
 pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
 ```
 
-### 使用魔搭的模型
+### 使用魔搭社区(可跳过)
 
-如果下载HuggingFace模型存在问题,我们已经支持了魔搭的ModelHub,只需要添加一个环境变量:
+如果您在 Hugging Face 模型的下载中遇到了问题,可以通过下述方法使用魔搭社区。
 
-```shell
-export USE_MODELSCOPE_HUB=1
+```bash
+export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 ```
 
-> [!NOTE]
->
-> 该环境变量仅支持整数,0或者不设置代表使用HuggingFace,其他值代表使用ModelScope
-
-之后就可以在命令行中指定魔搭的模型id:
+接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型)
 
-```shell
-python src/train_bash.py \
-    --model_name_or_path ZhipuAI/chatglm3-6b \
-    ... other arguments
-    # 在这个链接中可以看到所有可用模型: https://www.modelscope.cn/models
-```
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --model_name_or_path modelscope/Llama-2-7b-ms \
+    ... # 参数同上
+```
 
-Web demo目前也支持了魔搭, 在设置环境变量后即可使用:
+LLaMA Board 同样支持魔搭社区的模型下载。
 
-```shell
-CUDA_VISIBLE_DEVICES=0 python src/train_web.py
+```bash
+CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
 ```
 
 ### 单 GPU 训练

@@ -472,6 +478,9 @@ python src/export_model.py \
     --export_dir path_to_export
 ```
 
+> [!WARNING]
+> 尚不支持 GPTQ 量化模型的 LoRA 权重合并及导出。
+
 ### API 服务
 
 ```bash
Two binary image files also changed (one added, 56 KiB; one replaced, 140 KiB); contents not shown.
@@ -4,9 +4,10 @@ If you are using a custom dataset, please provide your dataset definition in the
 "dataset_name": {
   "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
   "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
-  "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
+  "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
   "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
   "subset": "the name of the subset. (optional, default: None)",
+  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset or not. (default: false)",
   "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
   "columns": {
@@ -2,11 +2,12 @@
 ```json
 "数据集名称": {
-  "hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
+  "hf_hub_url": "Hugging Face 的仓库地址(若指定,则忽略下列三个参数)",
   "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
   "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
-  "file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
+  "file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
   "subset": "数据集子集的名称(可选,默认:None)",
+  "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
   "ranking": "是否为偏好数据集(可选,默认:False)",
   "formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
   "columns": {
@@ -291,10 +291,11 @@
       "prompt": "content"
     }
   },
-  "starcoder": {
+  "starcoder_python": {
     "hf_hub_url": "bigcode/starcoderdata",
     "columns": {
       "prompt": "content"
-    }
+    },
+    "folder": "python"
   }
 }
@@ -1,8 +1,8 @@
 torch>=1.13.1
-transformers>=4.31.0,<4.35.0
-datasets>=2.14.0
+transformers>=4.36.0
+datasets>=2.14.3
 accelerate>=0.21.0
-peft>=0.6.0
+peft>=0.7.0
 trl>=0.7.4
 gradio>=3.38.0,<4.0.0
 scipy
|
||||||
from llmtuner.webui import create_ui, create_web_demo
|
from llmtuner.webui import create_ui, create_web_demo
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.3.2"
|
__version__ = "0.3.3"
|
||||||
|
|
|
@@ -15,7 +15,9 @@ from llmtuner.api.protocol import (
     ChatCompletionStreamResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
-    ChatCompletionResponseUsage
+    ChatCompletionResponseUsage,
+    ScoreEvaluationRequest,
+    ScoreEvaluationResponse
 )
 from llmtuner.chat import ChatModel
 from llmtuner.extras.misc import torch_gc

@@ -68,6 +70,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
+        if not chat_model.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

@@ -156,6 +161,17 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             yield to_json(chunk)
         yield "[DONE]"
 
+    @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
+    async def create_score_evaluation(request: ScoreEvaluationRequest):
+        if chat_model.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+        if len(request.messages) == 0:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+        scores = chat_model.get_scores(request.messages, max_length=request.max_length)
+        return ScoreEvaluationResponse(model=request.model, scores=scores)
+
     return app
@@ -81,3 +81,16 @@ class ChatCompletionStreamResponse(BaseModel):
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+
+
+class ScoreEvaluationRequest(BaseModel):
+    model: str
+    messages: List[str]
+    max_length: Optional[int] = None
+
+
+class ScoreEvaluationResponse(BaseModel):
+    id: Optional[str] = "scoreeval-default"
+    object: Optional[str] = "score.evaluation"
+    model: str
+    scores: List[float]
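Taken together, the two hunks above add a reward-scoring route alongside the chat completion route. A minimal client sketch for the new endpoint; the host, port, and model name are assumptions, while the `model`, `messages`, and `max_length` fields come from `ScoreEvaluationRequest`:

```python
import requests  # assumes the `requests` package is available

# Hypothetical local deployment of the API demo; adjust host and port to your setup.
API_URL = "http://localhost:8000/v1/score/evaluation"

payload = {
    "model": "my-reward-model",  # illustrative name only
    "messages": ["Paris is the capital of France.", "The moon is made of cheese."],
    "max_length": 1024,
}

response = requests.post(API_URL, json=payload)
response.raise_for_status()
# ScoreEvaluationResponse carries one float per input string.
print(response.json()["scores"])
```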
@@ -1,4 +1,5 @@
 import torch
+import tiktoken
 from dataclasses import dataclass
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread

@@ -22,8 +23,11 @@ class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
-        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-        self.tokenizer.padding_side = "left"
+        self.can_generate = (finetuning_args.stage == "sft")
+        self.model, self.tokenizer = load_model_and_tokenizer(
+            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+        )
+        self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.system_prompt = data_args.system_prompt

@@ -130,3 +134,41 @@ class ChatModel:
         thread.start()
 
         yield from streamer
+
+    @torch.inference_mode()
+    def get_scores(
+        self,
+        batch_input: List[str],
+        **input_kwargs
+    ) -> List[float]:
+        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=True)
+
+        max_length = input_kwargs.pop("max_length", None)
+        device = getattr(self.model.pretrained_model, "device", "cuda")
+
+        inputs = self.tokenizer(
+            batch_input,
+            padding=True,
+            truncation=True,
+            max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
+            pad_to_multiple_of=8,
+            return_tensors="pt",
+            **kwargs
+        ).to(device)
+
+        input_ids: torch.Tensor = inputs["input_ids"]
+        _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
+
+        if getattr(self.model.config, "model_type", None) == "chatglm":
+            values = torch.transpose(values, 0, 1)
+
+        scores = []
+        for i in range(input_ids.size(0)):
+            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
+            scores.append(values[i, end_index].nan_to_num().item())
+
+        return scores
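The new `get_scores` path only applies when the model is loaded with a value head, i.e. when `can_generate` is false because the inference stage is not "sft". A usage sketch under that assumption; the paths are placeholders and the argument names mirror the project's CLI flags rather than a documented constructor signature:

```python
from llmtuner.chat import ChatModel

# Assumes a reward model trained with the "rm" stage; paths and template are placeholders.
chat_model = ChatModel(dict(
    model_name_or_path="path_to_llama2_model",
    checkpoint_dir="path_to_rm_checkpoint",
    template="default",
    stage="rm",
))

scores = chat_model.get_scores(
    ["Response A to compare.", "Response B to compare."],
    max_length=1024,
)
print(scores)  # one value-head score per input string
```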
@@ -24,27 +24,27 @@ def get_dataset(
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
 
+        data_path, data_name, data_dir, data_files = None, None, None, None
         if dataset_attr.load_from in ("hf_hub", "ms_hub"):
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
-            data_files = None
+            data_dir = dataset_attr.folder
         elif dataset_attr.load_from == "script":
             data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
             data_name = dataset_attr.subset
-            data_files = None
         elif dataset_attr.load_from == "file":
-            data_path, data_name = None, None
-            data_files: List[str] = []
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
+            data_files = []
+            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            if os.path.isdir(local_path): # is directory
+                for file_name in os.listdir(local_path):
+                    data_files.append(os.path.join(local_path, file_name))
                     if data_path is None:
                         data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
                         assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
+            elif os.path.isfile(local_path): # is file
+                data_files.append(local_path)
+                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
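The point of threading `data_dir` through this loop is that Hub datasets (Hugging Face or ModelScope) can now point at a sub-folder of a repository, as the `starcoder_python` entry above does with `"folder": "python"`. A sketch of how the collected values would feed `datasets.load_dataset`; the actual call site is outside this hunk, so treat this as an assumption about the surrounding code:

```python
from datasets import load_dataset

# Illustrative values matching the starcoder_python entry in dataset_info.json.
data_path = "bigcode/starcoderdata"   # hf_hub_url
data_name = None                      # subset
data_dir = "python"                   # folder: restrict loading to this directory
data_files = None

dataset = load_dataset(
    data_path,
    data_name,
    data_dir=data_dir,
    data_files=data_files,
    split="train",
)
```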
@@ -541,9 +541,7 @@ register_template(
         "[INST] {{query}} [/INST]"
     ],
     system="",
-    sep=[
-        " "
-    ]
+    sep=[]
 )

@@ -650,6 +648,23 @@ register_template(
 )
 
 
+register_template(
+    name="xuanyuan",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}} Assistant:"
+    ],
+    system=(
+        "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
+        "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
+        "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
+    ),
+    sep=[]
+)
+
+
 register_template(
     name="xverse",
     prefix=[

@@ -707,6 +722,9 @@ register_template(
     sep=[
         "<|im_end|>\n"
     ],
+    stop_words=[
+        "<|im_end|>"
+    ],
     efficient_eos=True
 )
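For reference, the new `xuanyuan` template places the system text in the prefix and follows it with the turn prompt, with an empty separator. A rough illustration of the rendered first turn; the string assembly here is a simplification of the real template engine, and the query is an invented example:

```python
system = (
    "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
    "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
    "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
)
query = "请介绍一下你自己。"  # example user query

# prefix "{{system}}" followed by prompt "Human: {{query}} Assistant:" with sep=[]
prompt = system + f"Human: {query} Assistant:"
print(prompt)
```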
@@ -1,6 +1,7 @@
-import os
+from enum import Enum
 from collections import defaultdict, OrderedDict
-from typing import Dict, Optional, Union
+from typing import Dict, Optional
 
 
 CHOICES = ["A", "B", "C", "D"]

@@ -20,8 +21,6 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
 SUPPORTED_MODELS = OrderedDict()
 
-ALL_OFFICIAL_MODELS = OrderedDict()
-
 TRAINING_STAGES = {
     "Supervised Fine-Tuning": "sft",
     "Reward Modeling": "rm",

@@ -30,9 +29,13 @@
     "Pre-Training": "pt"
 }
 
+
+class DownloadSource(str, Enum):
+    DEFAULT = "hf"
+    MODELSCOPE = "ms"
+
+
 def register_model_group(
-    models: Dict[str, Union[str, Dict[str, str]]],
+    models: Dict[str, Dict[DownloadSource, str]],
     module: Optional[str] = None,
     template: Optional[str] = None
 ) -> None:

@@ -42,14 +45,7 @@ def register_model_group(
             prefix = name.split("-")[0]
         else:
             assert prefix == name.split("-")[0], "prefix should be identical."
-        ALL_OFFICIAL_MODELS[name] = [path] if isinstance(path, str) else list(path.values())
-        if not int(os.environ.get('USE_MODELSCOPE_HUB', '0')):
-            # If path is a string, we treat it as a huggingface model-id by default.
-            SUPPORTED_MODELS[name] = path["hf"] if isinstance(path, dict) else path
-        elif isinstance(path, dict) and "ms" in path:
-            # Use ModelScope modelhub
-            SUPPORTED_MODELS[name] = path["ms"]
+        SUPPORTED_MODELS[name] = path
         if module is not None:
             DEFAULT_MODULE[prefix] = module
         if template is not None:
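With this change `SUPPORTED_MODELS[name]` stores the whole `{DownloadSource: path}` mapping instead of a single pre-resolved repository id, so the hub is chosen at lookup time rather than at import time. A sketch of how a caller could resolve a path; the `get_model_path` helper and the import path are assumptions, since the real resolution lives outside these hunks:

```python
import os

# Assumed module path for the constants shown above.
from llmtuner.extras.constants import SUPPORTED_MODELS, DownloadSource

def get_model_path(name: str) -> str:
    # Hypothetical helper: prefer the ModelScope repository when USE_MODELSCOPE_HUB is set
    # and the group provides one, otherwise fall back to the default (Hugging Face) id.
    registry = SUPPORTED_MODELS[name]  # Dict[DownloadSource, str]
    if int(os.environ.get("USE_MODELSCOPE_HUB", "0")) and DownloadSource.MODELSCOPE in registry:
        return registry[DownloadSource.MODELSCOPE]
    return registry[DownloadSource.DEFAULT]

# e.g. get_model_path("LLaMA2-7B") -> "modelscope/Llama-2-7b-ms" when the flag is set
```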
@ -59,16 +55,16 @@ def register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Baichuan-7B-Base": {
|
"Baichuan-7B-Base": {
|
||||||
"hf": "baichuan-inc/Baichuan-7B",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||||
"ms": "baichuan-inc/baichuan-7B",
|
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
|
||||||
},
|
},
|
||||||
"Baichuan-13B-Base": {
|
"Baichuan-13B-Base": {
|
||||||
"hf": "baichuan-inc/Baichuan-13B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||||
"ms": "baichuan-inc/Baichuan-13B-Base",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
|
||||||
},
|
},
|
||||||
"Baichuan-13B-Chat": {
|
"Baichuan-13B-Chat": {
|
||||||
"hf": "baichuan-inc/Baichuan-13B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||||
"ms": "baichuan-inc/Baichuan-13B-Base",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="W_pack",
|
module="W_pack",
|
||||||
|
@ -79,20 +75,20 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Baichuan2-7B-Base": {
|
"Baichuan2-7B-Base": {
|
||||||
"hf": "baichuan-inc/Baichuan2-7B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||||
"ms": "baichuan-inc/Baichuan2-7B-Base",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
|
||||||
},
|
},
|
||||||
"Baichuan2-13B-Base": {
|
"Baichuan2-13B-Base": {
|
||||||
"hf": "baichuan-inc/Baichuan2-13B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||||
"ms": "baichuan-inc/Baichuan2-13B-Base",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
|
||||||
},
|
},
|
||||||
"Baichuan2-7B-Chat": {
|
"Baichuan2-7B-Chat": {
|
||||||
"hf": "baichuan-inc/Baichuan2-7B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||||
"ms": "baichuan-inc/Baichuan2-7B-Chat",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
|
||||||
},
|
},
|
||||||
"Baichuan2-13B-Chat": {
|
"Baichuan2-13B-Chat": {
|
||||||
"hf": "baichuan-inc/Baichuan2-13B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||||
"ms": "baichuan-inc/Baichuan2-13B-Chat",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="W_pack",
|
module="W_pack",
|
||||||
|
@ -103,16 +99,16 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"BLOOM-560M": {
|
"BLOOM-560M": {
|
||||||
"hf": "bigscience/bloom-560m",
|
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
||||||
"ms": "AI-ModelScope/bloom-560m",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
|
||||||
},
|
},
|
||||||
"BLOOM-3B": {
|
"BLOOM-3B": {
|
||||||
"hf": "bigscience/bloom-3b",
|
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
||||||
"ms": "AI-ModelScope/bloom-3b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
|
||||||
},
|
},
|
||||||
"BLOOM-7B1": {
|
"BLOOM-7B1": {
|
||||||
"hf": "bigscience/bloom-7b1",
|
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
||||||
"ms": "AI-ModelScope/bloom-7b1",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="query_key_value"
|
module="query_key_value"
|
||||||
|
@ -122,16 +118,16 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"BLOOMZ-560M": {
|
"BLOOMZ-560M": {
|
||||||
"hf": "bigscience/bloomz-560m",
|
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
||||||
"ms": "AI-ModelScope/bloomz-560m",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
|
||||||
},
|
},
|
||||||
"BLOOMZ-3B": {
|
"BLOOMZ-3B": {
|
||||||
"hf": "bigscience/bloomz-3b",
|
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
||||||
"ms": "AI-ModelScope/bloomz-3b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
|
||||||
},
|
},
|
||||||
"BLOOMZ-7B1-mt": {
|
"BLOOMZ-7B1-mt": {
|
||||||
"hf": "bigscience/bloomz-7b1-mt",
|
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
||||||
"ms": "AI-ModelScope/bloomz-7b1-mt",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="query_key_value"
|
module="query_key_value"
|
||||||
|
@ -141,12 +137,12 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"BlueLM-7B-Base": {
|
"BlueLM-7B-Base": {
|
||||||
"hf": "vivo-ai/BlueLM-7B-Base",
|
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||||
"ms": "vivo-ai/BlueLM-7B-Base",
|
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
|
||||||
},
|
},
|
||||||
"BlueLM-7B-Chat": {
|
"BlueLM-7B-Chat": {
|
||||||
"hf": "vivo-ai/BlueLM-7B-Chat",
|
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||||
"ms": "vivo-ai/BlueLM-7B-Chat",
|
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="bluelm"
|
template="bluelm"
|
||||||
|
@ -156,8 +152,8 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChatGLM2-6B-Chat": {
|
"ChatGLM2-6B-Chat": {
|
||||||
"hf": "THUDM/chatglm2-6b",
|
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
||||||
"ms": "ZhipuAI/chatglm2-6b",
|
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="query_key_value",
|
module="query_key_value",
|
||||||
|
@ -168,12 +164,12 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChatGLM3-6B-Base": {
|
"ChatGLM3-6B-Base": {
|
||||||
"hf": "THUDM/chatglm3-6b-base",
|
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
||||||
"ms": "ZhipuAI/chatglm3-6b-base",
|
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
|
||||||
},
|
},
|
||||||
"ChatGLM3-6B-Chat": {
|
"ChatGLM3-6B-Chat": {
|
||||||
"hf": "THUDM/chatglm3-6b",
|
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
||||||
"ms": "ZhipuAI/chatglm3-6b",
|
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="query_key_value",
|
module="query_key_value",
|
||||||
|
@ -184,59 +180,105 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChineseLLaMA2-1.3B": {
|
"ChineseLLaMA2-1.3B": {
|
||||||
"hf": "hfl/chinese-llama-2-1.3b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||||
"ms": "AI-ModelScope/chinese-llama-2-1.3b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-7B": {
|
"ChineseLLaMA2-7B": {
|
||||||
"hf": "hfl/chinese-llama-2-7b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||||
"ms": "AI-ModelScope/chinese-llama-2-7b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-13B": {
|
"ChineseLLaMA2-13B": {
|
||||||
"hf": "hfl/chinese-llama-2-13b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||||
"ms": "AI-ModelScope/chinese-llama-2-13b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-1.3B-Chat": {
|
"ChineseLLaMA2-1.3B-Chat": {
|
||||||
"hf": "hfl/chinese-alpaca-2-1.3b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||||
"ms": "AI-ModelScope/chinese-alpaca-2-1.3b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-7B-Chat": {
|
"ChineseLLaMA2-7B-Chat": {
|
||||||
"hf": "hfl/chinese-alpaca-2-7b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||||
"ms": "AI-ModelScope/chinese-alpaca-2-7b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-13B-Chat": {
|
"ChineseLLaMA2-13B-Chat": {
|
||||||
"hf": "hfl/chinese-alpaca-2-13b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||||
"ms": "AI-ModelScope/chinese-alpaca-2-13b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="llama2_zh"
|
template="llama2_zh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"DeepseekLLM-7B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
|
||||||
|
},
|
||||||
|
"DeepseekLLM-67B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
|
||||||
|
},
|
||||||
|
"DeepseekLLM-7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
|
||||||
|
},
|
||||||
|
"DeepseekLLM-67B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
template="deepseek"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"DeepseekCoder-6.7B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
|
||||||
|
},
|
||||||
|
"DeepseekCoder-33B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
|
||||||
|
},
|
||||||
|
"DeepseekCoder-6.7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
|
||||||
|
},
|
||||||
|
"DeepseekCoder-33B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
template="deepseekcoder"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Falcon-7B": {
|
"Falcon-7B": {
|
||||||
"hf": "tiiuae/falcon-7b",
|
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
||||||
"ms": "AI-ModelScope/falcon-7b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
|
||||||
},
|
},
|
||||||
"Falcon-40B": {
|
"Falcon-40B": {
|
||||||
"hf": "tiiuae/falcon-40b",
|
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
||||||
"ms": "AI-ModelScope/falcon-40b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
|
||||||
},
|
},
|
||||||
"Falcon-180B": {
|
"Falcon-180B": {
|
||||||
"hf": "tiiuae/falcon-180B",
|
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
||||||
"ms": "AI-ModelScope/falcon-180B",
|
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
|
||||||
},
|
},
|
||||||
"Falcon-7B-Chat": {
|
"Falcon-7B-Chat": {
|
||||||
"hf": "tiiuae/falcon-7b-instruct",
|
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
||||||
"ms": "AI-ModelScope/falcon-7b-instruct",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
|
||||||
},
|
},
|
||||||
"Falcon-40B-Chat": {
|
"Falcon-40B-Chat": {
|
||||||
"hf": "tiiuae/falcon-40b-instruct",
|
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
||||||
"ms": "AI-ModelScope/falcon-40b-instruct",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
|
||||||
},
|
},
|
||||||
"Falcon-180B-Chat": {
|
"Falcon-180B-Chat": {
|
||||||
"hf": "tiiuae/falcon-180B-chat",
|
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
||||||
"ms": "AI-ModelScope/falcon-180B-chat",
|
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="query_key_value",
|
module="query_key_value",
|
||||||
|
@ -247,20 +289,20 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"InternLM-7B": {
|
"InternLM-7B": {
|
||||||
"hf": "internlm/internlm-7b",
|
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||||
"ms": "Shanghai_AI_Laboratory/internlm-7b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
|
||||||
},
|
},
|
||||||
"InternLM-20B": {
|
"InternLM-20B": {
|
||||||
"hf": "internlm/internlm-20b",
|
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||||
"ms": "Shanghai_AI_Laboratory/internlm-20b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
|
||||||
},
|
},
|
||||||
"InternLM-7B-Chat": {
|
"InternLM-7B-Chat": {
|
||||||
"hf": "internlm/internlm-chat-7b",
|
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||||
"ms": "Shanghai_AI_Laboratory/internlm-chat-7b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
|
||||||
},
|
},
|
||||||
"InternLM-20B-Chat": {
|
"InternLM-20B-Chat": {
|
||||||
"hf": "internlm/internlm-chat-20b",
|
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||||
"ms": "Shanghai_AI_Laboratory/internlm-chat-20b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="intern"
|
template="intern"
|
||||||
|
@ -270,8 +312,8 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LingoWhale-8B": {
|
"LingoWhale-8B": {
|
||||||
"hf": "deeplang-ai/LingoWhale-8B",
|
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||||
"ms": "DeepLang/LingoWhale-8B",
|
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="qkv_proj"
|
module="qkv_proj"
|
||||||
|
@ -281,20 +323,20 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA-7B": {
|
"LLaMA-7B": {
|
||||||
"hf": "huggyllama/llama-7b",
|
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
||||||
"ms": "skyline2006/llama-7b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
|
||||||
},
|
},
|
||||||
"LLaMA-13B": {
|
"LLaMA-13B": {
|
||||||
"hf": "huggyllama/llama-13b",
|
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||||
"ms": "skyline2006/llama-13b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
|
||||||
},
|
},
|
||||||
"LLaMA-30B": {
|
"LLaMA-30B": {
|
||||||
"hf": "huggyllama/llama-30b",
|
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||||
"ms": "skyline2006/llama-30b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
|
||||||
},
|
},
|
||||||
"LLaMA-65B": {
|
"LLaMA-65B": {
|
||||||
"hf": "huggyllama/llama-65b",
|
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||||
"ms": "skyline2006/llama-65b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -303,28 +345,28 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA2-7B": {
|
"LLaMA2-7B": {
|
||||||
"hf": "meta-llama/Llama-2-7b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||||
"ms": "modelscope/Llama-2-7b-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
|
||||||
},
|
},
|
||||||
"LLaMA2-13B": {
|
"LLaMA2-13B": {
|
||||||
"hf": "meta-llama/Llama-2-13b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||||
"ms": "modelscope/Llama-2-13b-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
|
||||||
},
|
},
|
||||||
"LLaMA2-70B": {
|
"LLaMA2-70B": {
|
||||||
"hf": "meta-llama/Llama-2-70b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||||
"ms": "modelscope/Llama-2-70b-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
|
||||||
},
|
},
|
||||||
"LLaMA2-7B-Chat": {
|
"LLaMA2-7B-Chat": {
|
||||||
"hf": "meta-llama/Llama-2-7b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||||
"ms": "modelscope/Llama-2-7b-chat-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
|
||||||
},
|
},
|
||||||
"LLaMA2-13B-Chat": {
|
"LLaMA2-13B-Chat": {
|
||||||
"hf": "meta-llama/Llama-2-13b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||||
"ms": "modelscope/Llama-2-13b-chat-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
|
||||||
},
|
},
|
||||||
"LLaMA2-70B-Chat": {
|
"LLaMA2-70B-Chat": {
|
||||||
"hf": "meta-llama/Llama-2-70b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||||
"ms": "modelscope/Llama-2-70b-chat-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="llama2"
|
template="llama2"
|
||||||
|
@ -334,12 +376,28 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mistral-7B": {
|
"Mistral-7B": {
|
||||||
"hf": "mistralai/Mistral-7B-v0.1",
|
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
||||||
"ms": "AI-ModelScope/Mistral-7B-v0.1",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
|
||||||
},
|
},
|
||||||
"Mistral-7B-Chat": {
|
"Mistral-7B-Chat": {
|
||||||
"hf": "mistralai/Mistral-7B-Instruct-v0.1",
|
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
||||||
"ms": "AI-ModelScope/Mistral-7B-Instruct-v0.1",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
|
||||||
|
},
|
||||||
|
"Mistral-7B-v0.2-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
template="mistral"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Mixtral-8x7B": {
|
||||||
|
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1"
|
||||||
|
},
|
||||||
|
"Mixtral-8x7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="mistral"
|
template="mistral"
|
||||||
|
@ -349,8 +407,8 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"OpenChat3.5-7B-Chat": {
|
"OpenChat3.5-7B-Chat": {
|
||||||
"hf": "openchat/openchat_3.5",
|
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
||||||
"ms": "myxiongmodel/openchat_3.5",
|
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
template="openchat"
|
template="openchat"
|
||||||
|
@ -360,8 +418,8 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Phi1.5-1.3B": {
|
"Phi1.5-1.3B": {
|
||||||
"hf": "microsoft/phi-1_5",
|
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
||||||
"ms": "allspace/PHI_1-5",
|
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="Wqkv"
|
module="Wqkv"
|
||||||
|
@ -370,37 +428,69 @@ register_model_group(
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
|
"Qwen-1.8B": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
|
||||||
|
},
|
||||||
"Qwen-7B": {
|
"Qwen-7B": {
|
||||||
"hf": "Qwen/Qwen-7B",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
||||||
"ms": "qwen/Qwen-7B",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
|
||||||
},
|
},
|
||||||
"Qwen-14B": {
|
"Qwen-14B": {
|
||||||
"hf": "Qwen/Qwen-14B",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
||||||
"ms": "qwen/Qwen-14B",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
|
||||||
|
},
|
||||||
|
"Qwen-72B": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
|
||||||
|
},
|
||||||
|
"Qwen-1.8B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
|
||||||
},
|
},
|
||||||
"Qwen-7B-Chat": {
|
"Qwen-7B-Chat": {
|
||||||
"hf": "Qwen/Qwen-7B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
||||||
"ms": "qwen/Qwen-7B-Chat",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
|
||||||
},
|
},
|
||||||
"Qwen-14B-Chat": {
|
"Qwen-14B-Chat": {
|
||||||
"hf": "Qwen/Qwen-14B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||||
"ms": "qwen/Qwen-14B-Chat",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
|
||||||
|
},
|
||||||
|
"Qwen-72B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
|
||||||
|
},
|
||||||
|
"Qwen-1.8B-int8-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
|
||||||
|
},
|
||||||
|
"Qwen-1.8B-int4-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
|
||||||
},
|
},
|
||||||
"Qwen-7B-int8-Chat": {
|
"Qwen-7B-int8-Chat": {
|
||||||
"hf": "Qwen/Qwen-7B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||||
"ms": "qwen/Qwen-7B-Chat-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
|
||||||
},
|
},
|
||||||
"Qwen-7B-int4-Chat": {
|
"Qwen-7B-int4-Chat": {
|
||||||
"hf": "Qwen/Qwen-7B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||||
"ms": "qwen/Qwen-7B-Chat-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
|
||||||
},
|
},
|
||||||
"Qwen-14B-int8-Chat": {
|
"Qwen-14B-int8-Chat": {
|
||||||
"hf": "Qwen/Qwen-14B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||||
"ms": "qwen/Qwen-14B-Chat-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
|
||||||
},
|
},
|
||||||
"Qwen-14B-int4-Chat": {
|
"Qwen-14B-int4-Chat": {
|
||||||
"hf": "Qwen/Qwen-14B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||||
"ms": "qwen/Qwen-14B-Chat-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
|
||||||
|
},
|
||||||
|
"Qwen-72B-int8-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
|
||||||
|
},
|
||||||
|
"Qwen-72B-int4-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||||
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
module="c_attn",
|
module="c_attn",
|
||||||
|
@ -411,8 +501,8 @@ register_model_group(
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Skywork-13B-Base": {
|
"Skywork-13B-Base": {
|
||||||
"hf": "Skywork/Skywork-13B-base",
|
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||||
"ms": "skywork/Skywork-13B-base",
|
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@@ -421,39 +511,58 @@ register_model_group(
 register_model_group(
     models={
         "Vicuna1.5-7B-Chat": {
-            "hf": "lmsys/vicuna-7b-v1.5",
-            "ms": "AI-ModelScope/vicuna-7b-v1.5",
+            DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
+            DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
         },
         "Vicuna1.5-13B-Chat": {
-            "hf": "lmsys/vicuna-13b-v1.5",
-            "ms": "Xorbits/vicuna-13b-v1.5",
+            DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
+            DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
         }
     },
     template="vicuna"
 )


+register_model_group(
+    models={
+        "XuanYuan-70B": {
+            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
+        },
+        "XuanYuan-70B-Chat": {
+            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
+        },
+        "XuanYuan-70B-int8-Chat": {
+            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
+        },
+        "XuanYuan-70B-int4-Chat": {
+            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
+        }
+    },
+    template="xuanyuan"
+)
+
+
 register_model_group(
     models={
         "XVERSE-7B": {
-            "hf": "xverse/XVERSE-7B",
-            "ms": "xverse/XVERSE-7B",
+            DownloadSource.DEFAULT: "xverse/XVERSE-7B",
+            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
         },
         "XVERSE-13B": {
-            "hf": "xverse/XVERSE-13B",
-            "ms": "xverse/XVERSE-13B",
+            DownloadSource.DEFAULT: "xverse/XVERSE-13B",
+            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
         },
         "XVERSE-65B": {
-            "hf": "xverse/XVERSE-65B",
-            "ms": "xverse/XVERSE-65B",
+            DownloadSource.DEFAULT: "xverse/XVERSE-65B",
+            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
         },
         "XVERSE-7B-Chat": {
-            "hf": "xverse/XVERSE-7B-Chat",
-            "ms": "xverse/XVERSE-7B-Chat",
+            DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
+            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
         },
         "XVERSE-13B-Chat": {
-            "hf": "xverse/XVERSE-13B-Chat",
-            "ms": "xverse/XVERSE-13B-Chat",
+            DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
+            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
         }
     },
     template="xverse"
@@ -463,12 +572,12 @@ register_model_group(
 register_model_group(
     models={
         "Yayi-7B": {
-            "hf": "wenge-research/yayi-7b-llama2",
-            "ms": "AI-ModelScope/yayi-7b-llama2",
+            DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
         },
         "Yayi-13B": {
-            "hf": "wenge-research/yayi-13b-llama2",
-            "ms": "AI-ModelScope/yayi-13b-llama2",
+            DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
         }
     },
     template="yayi"
@@ -478,20 +587,28 @@ register_model_group(
 register_model_group(
     models={
         "Yi-6B": {
-            "hf": "01-ai/Yi-6B",
-            "ms": "01ai/Yi-6B",
+            DownloadSource.DEFAULT: "01-ai/Yi-6B",
+            DownloadSource.MODELSCOPE: "01ai/Yi-6B"
         },
         "Yi-34B": {
-            "hf": "01-ai/Yi-34B",
-            "ms": "01ai/Yi-34B",
+            DownloadSource.DEFAULT: "01-ai/Yi-34B",
+            DownloadSource.MODELSCOPE: "01ai/Yi-34B"
+        },
+        "Yi-6B-Chat": {
+            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
+            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
         },
         "Yi-34B-Chat": {
-            "hf": "01-ai/Yi-34B-Chat",
-            "ms": "01ai/Yi-34B-Chat",
+            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
+            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
+        },
+        "Yi-6B-int8-Chat": {
+            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
+            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits"
         },
         "Yi-34B-int8-Chat": {
-            "hf": "01-ai/Yi-34B-Chat-8bits",
-            "ms": "01ai/Yi-34B-Chat-8bits",
+            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
+            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
         }
     },
     template="yi"
@@ -501,12 +618,12 @@ register_model_group(
 register_model_group(
     models={
         "Zephyr-7B-Alpha-Chat": {
-            "hf": "HuggingFaceH4/zephyr-7b-alpha",
-            "ms": "AI-ModelScope/zephyr-7b-alpha",
+            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
         },
         "Zephyr-7B-Beta-Chat": {
-            "hf": "HuggingFaceH4/zephyr-7b-beta",
-            "ms": "modelscope/zephyr-7b-beta",
+            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
+            DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
         }
     },
     template="zephyr"
@@ -23,6 +23,7 @@ except ImportError:
 if TYPE_CHECKING:
     from transformers import HfArgumentParser
+    from llmtuner.hparams import ModelArguments


 class AverageMeter:

@@ -67,14 +68,18 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param


-def get_current_device() -> str:
+def get_current_device() -> torch.device:
     import accelerate
     if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
-    elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
-        return os.environ.get("LOCAL_RANK", "0")
+        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available():
+        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif torch.cuda.is_available():
+        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
     else:
-        return "cpu"
+        device = "cpu"
+
+    return torch.device(device)


 def get_logits_processor() -> "LogitsProcessorList":

@@ -117,3 +122,23 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+
+def try_download_model_from_ms(model_args: "ModelArguments") -> None:
+    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+        return
+
+    try:
+        from modelscope import snapshot_download # type: ignore
+        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
+        model_args.model_name_or_path = snapshot_download(
+            model_args.model_name_or_path,
+            revision=revision,
+            cache_dir=model_args.cache_dir
+        )
+    except ImportError:
+        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+
+
+def use_modelscope() -> bool:
+    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
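A minimal sketch of how the two helpers added above are intended to be used together (not part of the diff; the model id below is only an example, and any ModelScope-hosted id would behave the same way):

    import os
    from llmtuner.hparams import ModelArguments
    from llmtuner.extras.misc import try_download_model_from_ms, use_modelscope

    os.environ["USE_MODELSCOPE_HUB"] = "1"  # opt in to ModelScope downloads

    # example arguments; model_revision defaults to "main" and is mapped to "master"
    model_args = ModelArguments(model_name_or_path="qwen/Qwen-7B-Chat")

    if use_modelscope():
        # rewrites model_args.model_name_or_path to the local snapshot directory
        try_download_model_from_ms(model_args)

    print(model_args.model_name_or_path)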
@ -18,6 +18,7 @@ _flash_attn2_available = is_package_available("flash_attn") and get_package_vers
|
||||||
_jieba_available = is_package_available("jieba")
|
_jieba_available = is_package_available("jieba")
|
||||||
_matplotlib_available = is_package_available("matplotlib")
|
_matplotlib_available = is_package_available("matplotlib")
|
||||||
_nltk_available = is_package_available("nltk")
|
_nltk_available = is_package_available("nltk")
|
||||||
|
_requests_available = is_package_available("requests")
|
||||||
_rouge_available = is_package_available("rouge_chinese")
|
_rouge_available = is_package_available("rouge_chinese")
|
||||||
_starlette_available = is_package_available("sse_starlette")
|
_starlette_available = is_package_available("sse_starlette")
|
||||||
_uvicorn_available = is_package_available("uvicorn")
|
_uvicorn_available = is_package_available("uvicorn")
|
||||||
|
@ -43,6 +44,10 @@ def is_nltk_available():
|
||||||
return _nltk_available
|
return _nltk_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_requests_available():
|
||||||
|
return _requests_available
|
||||||
|
|
||||||
|
|
||||||
def is_rouge_available():
|
def is_rouge_available():
|
||||||
return _rouge_available
|
return _rouge_available
|
||||||
|
|
||||||
|
|
|
@@ -17,6 +17,7 @@ class DatasetAttr:
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
     subset: Optional[str] = None
+    folder: Optional[str] = None
     ranking: Optional[bool] = False
     formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"

@@ -184,6 +185,7 @@ class DataArguments:
             dataset_attr.content = dataset_info[name]["columns"].get("content", None)

         dataset_attr.subset = dataset_info[name].get("subset", None)
+        dataset_attr.folder = dataset_info[name].get("folder", None)
         dataset_attr.ranking = dataset_info[name].get("ranking", False)
         dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
         dataset_attr.system_prompt = prompt_list[i]
@@ -118,9 +118,9 @@ class RLHFArguments:
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."}
     )
-    reward_model_type: Optional[Literal["lora", "full"]] = field(
+    reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
         default="lora",
-        metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
+        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
     )

@@ -141,10 +141,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
     )
-    neft_alpha: Optional[float] = field(
-        default=0,
-        metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
-    )
     export_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory to save the exported model."}
@@ -8,8 +8,8 @@ class ModelArguments:
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
     """
     model_name_or_path: str = field(
-        metadata={"help": "Path to pretrained model or model identifier "
-                          "from huggingface.co/models or modelscope.cn/models."}
+        metadata={"help": "Path to pretrained model or model identifier from \
+                  huggingface.co/models or modelscope.cn/models."}
     )
     cache_dir: Optional[str] = field(
         default=None,
@@ -87,7 +87,7 @@ def init_adapter(

     if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
-            target_modules = find_all_linear_modules(model, model_args.quantization_bit)
+            target_modules = find_all_linear_modules(model)
         else:
             target_modules = finetuning_args.lora_target

@@ -102,6 +102,9 @@ def init_adapter(
         )
         model = get_peft_model(model, lora_config)

+        for param in filter(lambda p: p.requires_grad, model.parameters()):
+            param.data = param.data.to(torch.float32)
+
     if model_args.checkpoint_dir is not None:
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
@@ -1,9 +1,8 @@
-import math
 import os
+import math
 import torch
 from types import MethodType
-from typing import TYPE_CHECKING, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

 from transformers import (
     AutoConfig,

@@ -23,13 +22,12 @@ try:
 except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
     from transformers.deepspeed import is_deepspeed_zero3_enabled

-from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
+from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
-from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

@@ -39,10 +37,10 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
-require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
+require_version("transformers>=4.36.0", "To fix: pip install transformers>=4.36.0")
+require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
+require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
 require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")

@@ -50,7 +48,7 @@ def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
+    add_valuehead: Optional[bool] = False
 ) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.

@@ -58,6 +56,8 @@ def load_model_and_tokenizer(
     Support both training and inference.
     """
+    try_download_model_from_ms(model_args)
+
     config_kwargs = {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,

@@ -65,8 +65,6 @@ def load_model_and_tokenizer(
         "token": model_args.hf_hub_token
     }

-    try_download_model_from_ms(model_args)
-
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,

@@ -125,45 +123,42 @@

     # Set FlashAttention-2
     if model_args.flash_attn:
-        if getattr(config, "model_type", None) == "llama":
-            if is_flash_attn2_available():
-                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-                logger.info("Using FlashAttention-2 for faster training and inference.")
-            else:
-                logger.warning("FlashAttention-2 is not installed.")
-        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+        elif getattr(config, "model_type", None) == "qwen":
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
-            logger.warning("Current model does not support FlashAttention.")
-    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
-        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
-        logger.warning("Using `--flash_attn` for faster training in large context length.")
+            setattr(config, "attn_implementation", "flash_attention_2")
+            logger.info("Using FlashAttention-2 for faster training and inference.")

     # Set shift short attention (S^2-Attn)
     if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) == "llama":
-            setattr(config, "group_size_ratio", 0.25)
-            logger.info("Using shift short attention with group_size_ratio=1/4.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
+        logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
+        # if getattr(config, "model_type", None) == "llama":
+        #     setattr(config, "group_size_ratio", 0.25)
+        #     logger.info("Using shift short attention with group_size_ratio=1/4.")
+        # else:
+        #     logger.warning("Current model does not support shift short attention.")
+
+    # Quantization configurations (using gptq or awq)
+    if getattr(config, "quantization_config", None):
+        if model_args.quantization_bit is not None: # remove bnb quantization
+            model_args.quantization_bit = None
+        config_kwargs["device_map"] = {"": get_current_device()}
+        quantization_config = getattr(config, "quantization_config", None)
+        logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))

     # Quantization configurations (using bitsandbytes library)
     if model_args.quantization_bit is not None:
-        if getattr(config, "quantization_config", None):
-            raise ValueError("Remove `quantization_bit` if you are using a quantized model.")
-
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["load_in_8bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

         if model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["load_in_4bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=model_args.compute_dtype,

@@ -183,6 +178,9 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

+    # Resize token embeddings
+    resize_embedding_layer(model, tokenizer)
+
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

@@ -203,12 +201,12 @@ def load_model_and_tokenizer(
     # Initialize adapters
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
-    model = model.train() if is_trainable else model.eval()

     # Prepare model with valuehead for RLHF
-    if stage in ["rm", "ppo"]:
+    if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
+        ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
+        setattr(model, "_keys_to_ignore_on_save", ignore_modules)
         setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path

@@ -222,6 +220,9 @@ def load_model_and_tokenizer(
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
         model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
+        model.eval()
+    else:
+        model.train()

     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(

@@ -232,16 +233,3 @@ def load_model_and_tokenizer(
         logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

     return model, tokenizer
-
-
-def try_download_model_from_ms(model_args):
-    if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and not os.path.exists(model_args.model_name_or_path):
-        try:
-            from modelscope import snapshot_download
-            revision = model_args.model_revision
-            if revision == 'main':
-                revision = 'master'
-            model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, revision)
-        except ImportError as e:
-            raise ImportError(f'You are using `USE_MODELSCOPE_HUB=1` but you have no modelscope sdk installed. '
-                              f'Please install it by `pip install modelscope -U`') from e
@@ -11,6 +11,7 @@ from llmtuner.hparams import ModelArguments, FinetuningArguments

 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils import PreTrainedTokenizer
     from llmtuner.hparams import DataArguments

@@ -22,10 +23,10 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     Dispatches a pre-trained model to GPUs with balanced memory.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
     """
-    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+    if getattr(model, "quantization_method", None): # already set on current device
         return model

-    if torch.cuda.device_count() > 1:
+    if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
         from accelerate import dispatch_model
         from accelerate.utils import infer_auto_device_map, get_balanced_memory

@@ -42,18 +43,18 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
         return model.cuda()


-def find_all_linear_modules(
-    model: "PreTrainedModel",
-    quantization_bit: Optional[int] = None
-) -> List[str]:
+def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
     """
-    if quantization_bit is not None:
-        import bitsandbytes as bnb
-        linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
-    else:
+    quantization_method = getattr(model, "quantization_method", None)
+    if quantization_method is None:
         linear_cls = torch.nn.Linear
+    elif quantization_method == "bitsandbytes":
+        import bitsandbytes as bnb
+        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
+    else:
+        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))

     output_layer_names = ["lm_head"]
     if model.config.model_type == "chatglm":

@@ -147,17 +148,6 @@ def prepare_model_for_training(
             param.data = param.data.to(torch.float32)
         logger.info("Upcasting weights in layernorm in float32.")

-    if finetuning_args.neft_alpha > 1e-6:
-        def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
-            if module.training:
-                dims = torch.tensor(output.size(1) * output.size(2))
-                mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
-                output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
-            return output
-
-        model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
-        logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
-
     if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()

@@ -181,3 +171,18 @@ def prepare_model_for_training(
         output_layer.register_forward_hook(fp32_forward_post_hook)

     return model
+
+
+def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
+    r"""
+    Resize token embeddings.
+    """
+    if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
+        logger.warning("Current model does not support resizing token embeddings.")
+        return
+
+    old_vocab_size = model.get_input_embeddings().weight.size(0)
+    if len(tokenizer) != old_vocab_size:
+        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
+        new_vocab_size = model.get_input_embeddings().weight.size(0)
+        logger.info("Resized token embeddings from {} to {}.".format(old_vocab_size, new_vocab_size))
@@ -25,11 +25,11 @@ def run_dpo(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
-        pad_to_multiple_of=4,
+        pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )

@@ -37,7 +37,7 @@ def run_dpo(
     if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
         ref_model = model
     else:
-        ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
+        ref_model = create_ref_model(model_args, finetuning_args)

     # Update arguments
     training_args_dict = training_args.to_dict()
@@ -3,12 +3,11 @@ import sys
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

-from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
-from transformers.trainer_pt_utils import remove_dummy_checkpoint

 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits

@@ -16,7 +15,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback

@@ -66,7 +65,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

-        if reward_model is not None:
+        if finetuning_args.reward_model_type == "full":
             if self.is_deepspeed_enabled:
                 if not (
                     getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)

@@ -200,7 +199,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )

     @torch.no_grad()
-    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """

@@ -208,7 +207,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             layernorm_params = dump_layernorm(self.model)

         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(
+        generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
             logits_processor=get_logits_processor(),
             **batch

@@ -217,7 +216,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)

-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        query = batch["input_ids"].detach().cpu()
+        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()

@@ -242,17 +242,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     ) -> List[torch.Tensor]:
         r"""
         Computes scores using given reward model.
+
+        Both inputs and outputs are put on CPU.
         """
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "api":
+            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            return get_rewards_from_server(self.reward_model, messages)
+
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="reward")
+            reward_model = self.model
+        else:
+            reward_model = self.reward_model

         batch = self.prepare_model_inputs(queries, responses)

         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-            reward_model = self.reward_model if self.reward_model is not None else self.model
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)

-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
             values = torch.transpose(values, 0, 1)

         rewards = []

@@ -261,7 +270,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")

         return rewards

@@ -351,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
             except ValueError:
                 logger.warning(
-                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                    " zero_to_fp32.py to recover weights"
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                    " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
+                    file = os.path.join(output_dir, filename)
+                    if os.path.isfile(file):
+                        os.remove(file)
+
             self.model.save_checkpoint(output_dir) # wrapped model
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from llmtuner.extras.packages import is_requests_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
|
if is_requests_available():
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
payload = {"model": "model", "messages": messages}
|
||||||
|
response = requests.post(server_url, json=payload, headers=headers)
|
||||||
|
rewards = json.loads(response.text)["scores"]
|
||||||
|
return torch.Tensor(rewards)
|
||||||
|
|
||||||
|
|
||||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||||
if target == "reward": # save default head temporarily
|
if target == "reward": # save default head temporarily
|
||||||
|
|
|
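A sketch of the request/response contract that `get_rewards_from_server` above assumes (the URL and score value are hypothetical examples, not part of the diff): the trainer POSTs a JSON body with a `messages` list and reads a `scores` list back from the response.

    import json
    import requests
    import torch

    server_url = "http://localhost:8000/v1/score"  # hypothetical reward server endpoint
    payload = {"model": "model", "messages": ["decoded query + response text"]}

    response = requests.post(server_url, json=payload, headers={"Content-Type": "application/json"})
    scores = json.loads(response.text)["scores"]  # e.g. [0.42], one score per message
    rewards = torch.Tensor(scores)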
@@ -28,14 +28,14 @@ def run_ppo(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")

     tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

     # Create reference model and reward model
-    ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
+    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
     reward_model = create_reward_model(model, model_args, finetuning_args)

     # Create ppo config
@@ -22,7 +22,7 @@ def run_pt(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -25,9 +25,9 @@ def run_rm(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

     # Update arguments
     training_args_dict = training_args.to_dict()
@@ -26,7 +26,7 @@ def run_sft(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")

     if training_args.predict_with_generate:

@@ -34,7 +34,7 @@ def run_sft(

     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
-        pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
+        pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
@@ -1,5 +1,5 @@
 import torch
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Optional, Union

 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import ModelArguments, FinetuningArguments

@@ -35,7 +35,7 @@ def create_modelcard_and_push(
 def create_ref_model(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    stage: Literal["ppo", "dpo"]
+    add_valuehead: Optional[bool] = False
 ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.

@@ -51,13 +51,17 @@ def create_ref_model(
         ))
         ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
+        ref_model, _ = load_model_and_tokenizer(
+            ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+        )
         logger.info("Created reference model from {}".format(finetuning_args.ref_model))
     else:
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
+            ref_model, _ = load_model_and_tokenizer(
+                model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+            )
             logger.info("Created reference model from the model itself.")

     return ref_model

@@ -71,7 +75,11 @@ def create_reward_model(
     r"""
     Creates reward model for PPO training.
     """
-    if finetuning_args.reward_model_type == "lora":
+    if finetuning_args.reward_model_type == "api":
+        assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
+        logger.info("Use reward server {}".format(finetuning_args.reward_model))
+        return finetuning_args.reward_model
+    elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
         for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
             if "default" in name:

@@ -93,7 +101,9 @@ def create_reward_model(
         ))
         reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
-        logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+        reward_model, _ = load_model_and_tokenizer(
+            reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
+        )
+        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model
@@ -11,18 +11,17 @@ from transformers.utils import (
     ADAPTER_SAFE_WEIGHTS_NAME
 )

 from llmtuner.extras.constants import (
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
     SUPPORTED_MODELS,
-    ALL_OFFICIAL_MODELS,
-    TRAINING_STAGES
+    TRAINING_STAGES,
+    DownloadSource
 )
+from llmtuner.extras.misc import use_modelscope
 from llmtuner.hparams.data_args import DATA_CONFIG


 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"

@@ -66,10 +65,15 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona

 def get_model_path(model_name: str) -> str:
     user_config = load_config()
-    cached_path = user_config["path_dict"].get(model_name, None)
-    if cached_path in ALL_OFFICIAL_MODELS.get(model_name, []):
-        cached_path = None
-    return cached_path or SUPPORTED_MODELS.get(model_name, "")
+    path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, [])
+    model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "")
+    if (
+        use_modelscope()
+        and path_dict.get(DownloadSource.MODELSCOPE)
+        and model_path == path_dict.get(DownloadSource.DEFAULT)
+    ): # replace path
+        model_path = path_dict.get(DownloadSource.MODELSCOPE)
+    return model_path


 def get_prefix(model_name: str) -> str:
@ -65,16 +65,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||||
neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
train_on_prompt = gr.Checkbox(value=False)
|
train_on_prompt = gr.Checkbox(value=False)
|
||||||
upcast_layernorm = gr.Checkbox(value=False)
|
upcast_layernorm = gr.Checkbox(value=False)
|
||||||
|
|
||||||
input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
|
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, train_on_prompt, upcast_layernorm})
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(
|
||||||
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
||||||
neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
|
neftune_alpha=neftune_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
|
||||||
))
|
))
|
||||||
|
|
||||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||||
|
|
|
@@ -333,7 +333,7 @@ LOCALES = {
             "info": "学习率预热采用的步数。"
         }
     },
-    "neft_alpha": {
+    "neftune_alpha": {
         "en": {
             "label": "NEFTune Alpha",
             "info": "Magnitude of noise adding to embedding vectors."
@@ -119,7 +119,7 @@ class Runner:
             logging_steps=get("train.logging_steps"),
             save_steps=get("train.save_steps"),
             warmup_steps=get("train.warmup_steps"),
-            neft_alpha=get("train.neft_alpha"),
+            neftune_noise_alpha=get("train.neftune_alpha"),
             train_on_prompt=get("train.train_on_prompt"),
             upcast_layernorm=get("train.upcast_layernorm"),
             lora_rank=get("train.lora_rank"),