forked from p04798526/LLaMA-Factory-Mirror
support DPO training (2305.18290)
This commit is contained in:
parent
685dae4eff
commit
3ec4351cfd
README.md (60 lines changed)
@@ -12,6 +12,8 @@
 
 ## Changelog
 
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
+
 [23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
 
 [23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
@@ -54,24 +56,18 @@
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
 
-> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
-> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc.
+- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
 
 ## Supported Training Approaches
 
-- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - Full-parameter tuning
-  - Partial-parameter tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-  - Full-parameter tuning
-  - Partial-parameter tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [RLHF](https://arxiv.org/abs/2203.02155)
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach               | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| ---------------------- | -------------- | ----------------- | ---- | ----- |
+| Pre-Training           | ✅             | ✅                | ✅   | ✅    |
+| Supervised Fine-Tuning | ✅             | ✅                | ✅   | ✅    |
+| Reward Model Training  |                |                   | ✅   | ✅    |
+| PPO Training           |                |                   | ✅   | ✅    |
+| DPO Training           | ✅             |                   | ✅   | ✅    |
 
 ## Provided Datasets
 
@@ -88,7 +84,6 @@
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [Self-cognition (zh)](data/self_cognition.json)
   - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-  - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
   - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
   - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
   - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +98,7 @@
   - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
   - [UltraChat (en)](https://github.com/thunlp/UltraChat)
   - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling:
+- For reward modelling or DPO training:
   - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +134,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 ### Dependence Installation (optional)
 
 ```bash
-git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
@@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 
 Currently the web UI only supports training on **a single GPU**.
 
-### (Continually) Pre-Training
+### Pre-Training
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
     --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
@@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-### PPO Training (RLHF)
+### PPO Training
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --plot_loss
 ```
 
+### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage dpo \
+    --model_name_or_path path_to_your_model \
+    --do_train \
+    --dataset comparison_gpt4_en \
+    --template default \
+    --finetuning_type lora \
+    --resume_lora_training False \
+    --checkpoint_dir path_to_sft_checkpoint \
+    --output_dir path_to_dpo_checkpoint \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --plot_loss \
+    --fp16
+```
+
 ### Distributed Training
 
 ```bash
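For reference, the objective behind the new `dpo` stage is the DPO loss from the paper linked in the changelog (arXiv:2305.18290), as implemented by trl's `DPOTrainer` that this commit builds on; `y_w` is the chosen response, `y_l` the rejected one, and β is the `dpo_beta` field added to `FinetuningArguments` further down in this diff:

```latex
% DPO loss (Rafailov et al., 2023, arXiv:2305.18290)
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) =
  -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[
    \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right)
  \right]
```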
README_zh.md (68 lines changed)
@@ -12,7 +12,9 @@
 
 ## Changelog
 
-[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Note that using the Qwen-7B-Chat model requires adding the `--template chatml` argument.
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) for details (experimental feature).
+
+[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Please add the `--template chatml` argument when using the Qwen-7B-Chat model.
 
 [23/07/31] Now we support streaming loading of the training data. Try the `--streaming` and `--max_steps 100` arguments to load your dataset in streaming mode.
 
@@ -54,41 +56,34 @@
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
 
-> * The **default module** is the default value of the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
-> * For all "base" models, the `--template` argument can be `default`, `alpaca`, `vicuna`, etc.
+- The **default module** shows some of the available options of the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For all "base" (Base) models, the `--template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But make sure to use the corresponding template for the "chat" (Chat) models.
 
-## Fine-tuning Methods
+## Training Methods
 
-- [Continual pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - Full-parameter fine-tuning
-  - Partial-parameter fine-tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-  - Full-parameter fine-tuning
-  - Partial-parameter fine-tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Reinforcement learning from human feedback (RLHF)](https://arxiv.org/abs/2203.02155)
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach               | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| ---------------------- | -------------- | ----------------- | ---- | ----- |
+| Pre-Training           | ✅             | ✅                | ✅   | ✅    |
+| Supervised Fine-Tuning | ✅             | ✅                | ✅   | ✅    |
+| Reward Model Training  |                |                   | ✅   | ✅    |
+| PPO Training           |                |                   | ✅   | ✅    |
+| DPO Training           | ✅             |                   | ✅   | ✅    |
 
 ## Datasets
 
-- For continual pre-training:
+- For pre-training:
   - [Wiki Demo (en)](data/wiki_demo.txt)
   - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
   - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
   - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
   - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
 - For supervised fine-tuning:
   - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
   - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [Self-cognition (zh)](data/self_cognition.json)
   - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-  - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
   - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
   - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
   - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +98,7 @@
   - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
   - [UltraChat (en)](https://github.com/thunlp/UltraChat)
   - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward model training:
+- For reward model or DPO training:
   - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +134,6 @@ huggingface-cli login
 ### Environment Setup (can be skipped)
 
 ```bash
-git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
@@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 
 Currently the web UI only supports training on **a single GPU**.
 
-### Continual Pre-Training
+### Pre-Training
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
     --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
@@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-### RLHF Training
+### PPO Training
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --plot_loss
 ```
 
+### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage dpo \
+    --model_name_or_path path_to_your_model \
+    --do_train \
+    --dataset comparison_gpt4_zh \
+    --template default \
+    --finetuning_type lora \
+    --resume_lora_training False \
+    --checkpoint_dir path_to_sft_checkpoint \
+    --output_dir path_to_dpo_checkpoint \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --plot_loss \
+    --fp16
+```
+
 ### Multi-GPU Distributed Training
 
 ```bash

data/dataset_info.json
@@ -49,26 +49,6 @@
       "history": "history"
     }
   },
-  "refgpt_zh_p1": {
-    "file_name": "refgpt_zh_50k_p1.json",
-    "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
-    "columns": {
-      "prompt": "instruction",
-      "query": "input",
-      "response": "output",
-      "history": "history"
-    }
-  },
-  "refgpt_zh_p2": {
-    "file_name": "refgpt_zh_50k_p2.json",
-    "file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
-    "columns": {
-      "prompt": "instruction",
-      "query": "input",
-      "response": "output",
-      "history": "history"
-    }
-  },
   "lima": {
     "file_name": "lima.json",
     "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",
data/refgpt_zh_50k_p1.json (520932 lines changed)
File diff suppressed because it is too large

data/refgpt_zh_50k_p2.json (506158 lines changed)
File diff suppressed because one or more lines are too long

requirements.txt
@@ -3,7 +3,7 @@ transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.21.0
 peft>=0.4.0
-trl>=0.4.7
+trl>=0.5.0
 scipy
 sentencepiece
 tiktoken
@@ -16,4 +16,3 @@ pydantic==1.10.11
 fastapi==0.95.1
 sse-starlette
 matplotlib
-huggingface_hub
@@ -7,7 +7,7 @@ def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
-    # Visit http://localhost:8000/docs for document.
+    print("Visit http://localhost:8000/docs for API document.")


 if __name__ == "__main__":
@@ -18,7 +18,6 @@ class ChatModel:
         self.model = self.model.eval() # change to eval mode
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.source_prefix = data_args.source_prefix
-        self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
         self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
 
     def process_args(
@@ -53,7 +52,7 @@ class ChatModel:
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
             logits_processor=get_logits_processor(),
-            stopping_criteria=get_stopping_criteria(self.stop_ids)
+            stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
         ))
 
         if max_length:
@@ -46,7 +46,6 @@ def preprocess_dataset(
             k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
-        result["labels"] = result["input_ids"].copy()
         return result
 
     def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -95,24 +94,22 @@ def preprocess_dataset(
         return model_inputs
 
     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
-        model_inputs = {"accept_ids": [], "reject_ids": []}
+        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
         for query, response, history, prefix in construct_example(examples):
-            source_ids, accept_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
-            source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
+            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
+            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
 
-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(accept_ids) > data_args.max_target_length:
-                accept_ids = accept_ids[:data_args.max_target_length]
-            if len(reject_ids) > data_args.max_target_length:
-                reject_ids = reject_ids[:data_args.max_target_length]
+            if len(prompt_ids) > data_args.max_source_length:
+                prompt_ids = prompt_ids[:data_args.max_source_length]
+            if len(chosen_ids) > data_args.max_target_length:
+                chosen_ids = chosen_ids[:data_args.max_target_length]
+            if len(rejected_ids) > data_args.max_target_length:
+                rejected_ids = rejected_ids[:data_args.max_target_length]
 
-            accept_ids = source_ids + accept_ids
-            reject_ids = source_ids + reject_ids
-
-            model_inputs["accept_ids"].append(accept_ids)
-            model_inputs["reject_ids"].append(reject_ids)
+            model_inputs["prompt_ids"].append(prompt_ids)
+            model_inputs["chosen_ids"].append(chosen_ids)
+            model_inputs["rejected_ids"].append(rejected_ids)
         return model_inputs
 
     def print_supervised_dataset_example(example):
@@ -124,10 +121,12 @@ def preprocess_dataset(
         ], skip_special_tokens=False)))
 
     def print_pairwise_dataset_example(example):
-        print("accept_ids:\n{}".format(example["accept_ids"]))
-        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
-        print("reject_ids:\n{}".format(example["reject_ids"]))
-        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
+        print("prompt_ids:\n{}".format(example["prompt_ids"]))
+        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+        print("chosen_ids:\n{}".format(example["chosen_ids"]))
+        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+        print("rejected_ids:\n{}".format(example["rejected_ids"]))
+        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
 
     def print_unsupervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
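To make the reshaped pairwise format concrete, here is a toy example of what `preprocess_pairwise_dataset` now emits (the id values are made up; real ones come from the tokenizer and template):

```python
# Hypothetical output of the new pairwise preprocessing for one example.
# Previously each example stored two full sequences (prompt + answer);
# now the prompt is kept separate so the DPO collator can re-pair it
# with either answer.
example = {
    "prompt_ids":   [1, 523, 87],   # <bos> X
    "chosen_ids":   [642, 99, 2],   # Y1 <eos>
    "rejected_ids": [713, 2],       # Y2 <eos>
}
```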
@@ -7,10 +7,16 @@ from datetime import timedelta
 from transformers import TrainerCallback
 from transformers.trainer_utils import has_length
 
+from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.logging import get_logger
+
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
 
 
+logger = get_logger(__name__)
+
+
 class LogCallback(TrainerCallback):
 
     def __init__(self, runner=None):
@@ -38,6 +44,9 @@ class LogCallback(TrainerCallback):
         self.in_training = True
         self.start_time = time.time()
         self.max_steps = state.max_steps
+        if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
+            logger.warning("Previous log file in this folder will be deleted.")
+            os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
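Since `LOG_FILE_NAME` (defined in the constants hunk below) points at a JSON-lines file, a minimal reader looks like this; a sketch under the assumption of one JSON object per line, and the helper name is ours, not the repo's:

```python
import json
import os

def read_trainer_log(output_dir: str, name: str = "trainer_log.jsonl") -> list:
    """Read the JSONL training log that LogCallback maintains, one dict per line."""
    path = os.path.join(output_dir, name)
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```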
@@ -1,10 +1,12 @@
 IGNORE_INDEX = -100
 
+LOG_FILE_NAME = "trainer_log.jsonl"
+
 VALUE_HEAD_FILE_NAME = "value_head.bin"
 
 FINETUNING_ARGS_NAME = "finetuning_args.json"
 
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
 
 METHODS = ["full", "freeze", "lora"]
@@ -1,7 +1,11 @@
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
-from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
+from transformers import (
+    LogitsProcessor,
+    LogitsProcessorList,
+    StoppingCriteria,
+    StoppingCriteriaList
+)
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 
@@ -61,7 +61,7 @@ class Template:
         prefix: Optional[str] = None
     ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
         r"""
-        Aligns inputs to a special format.
+        Aligns inputs to the standard format.
         """
         prefix = [prefix] if prefix else self.prefix # use prefix if provided
         history = history if (history and self.use_history) else []
@@ -92,28 +92,32 @@
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + sep + query    resp + eos
+        Turn t: sep + bos + query             resp + eos
         """
         bos_ids, eos_ids = self._get_special_ids(tokenizer)
         sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx != 0:
-                prefix_ids = sep_ids
-            elif prefix:
-                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
-            else:
-                prefix_ids = []
+            if turn_idx == 0:
+                if prefix: # has prefix
+                    prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
+                else:
+                    prefix_ids = bos_ids
+            else:
+                prefix_ids = sep_ids + bos_ids
 
-            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
-            encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
         return encoded_pairs
 
     def _convert_inputs_to_ids(
         self,
         tokenizer: "PreTrainedTokenizer",
         context: List[Union[str, Dict[str, str]]],
-        query: Optional[str] = ""
+        query: Optional[str] = "",
+        idx: Optional[str] = ""
     ) -> List[int]:
         r"""
         Converts context to token ids.
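A toy re-implementation of just the new prefix branch, to show where bos and sep now land (stand-in token strings, not the library's API):

```python
# Mirrors the branching added above: turn 0 gets bos (plus prefix and sep
# when a prefix exists); every later turn gets sep followed by bos.
def prefix_ids_for_turn(turn_idx: int, bos: list, sep: list, prefix: list) -> list:
    if turn_idx == 0:
        return bos + prefix + sep if prefix else bos
    return sep + bos

assert prefix_ids_for_turn(0, ["<bos>"], ["<sep>"], []) == ["<bos>"]
assert prefix_ids_for_turn(0, ["<bos>"], ["<sep>"], ["<sys>"]) == ["<bos>", "<sys>", "<sep>"]
assert prefix_ids_for_turn(2, ["<bos>"], ["<sep>"], ["<sys>"]) == ["<sep>", "<bos>"]
```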
@@ -127,6 +131,7 @@
         for elem in context:
             if isinstance(elem, str):
                 elem = elem.replace("{{query}}", query, 1)
+                elem = elem.replace("{{idx}}", idx, 1)
                 token_ids = token_ids + tokenizer.encode(elem, **kwargs)
             elif isinstance(elem, dict):
                 token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
@@ -146,10 +151,12 @@ class Llama2Template(Template):
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + query    resp + eos
+        Turn t: bos + query             resp + eos
         """
         bos_ids, eos_ids = self._get_special_ids(tokenizer)
         encoded_pairs = []
-        assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
+        assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
         for turn_idx, (query, resp) in enumerate(history):
             if turn_idx == 0: # llama2 template has not sep_ids
                 query = prefix[0] + query
@@ -187,10 +194,11 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
 
-    if tokenizer.eos_token_id is None: # inplace method
-        if len(template.stop_words):
-            tokenizer.eos_token = template.stop_words[0]
-        else:
-            tokenizer.eos_token = "<|endoftext|>"
-        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+    if len(template.stop_words): # inplace method
+        tokenizer.eos_token = template.stop_words[0]
+        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
 
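Restated as a toy function, the new eos handling gives template stop words priority over an existing eos token and only falls back to `<|endoftext|>` when neither is available (schematic only, not the real in-place tokenizer mutation):

```python
# Schematic of the reordered logic: templates with stop words now replace
# the eos token even if the tokenizer already defines one.
def fix_eos(eos_token, stop_words):
    if stop_words:
        return stop_words[0]    # "Replace eos token"
    if eos_token is None:
        return "<|endoftext|>"  # "Add eos token"
    return eos_token

assert fix_eos("</s>", ["<|im_end|>"]) == "<|im_end|>"
assert fix_eos(None, []) == "<|endoftext|>"
```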
@@ -422,12 +430,13 @@ register_template(
     name="baichuan",
     prefix=[],
     prompt=[
-        {"token": "<reserved_102>"},
-        "{{query}}",
-        {"token": "<reserved_103>"}
+        {"token": "<reserved_102>"}, # user token (a little difference in position)
+        "{{query}}"
     ],
     sep=[],
-    stop_words=[],
+    stop_words=[
+        "<reserved_103>" # assistant token
+    ],
     use_history=True
 )
 
@@ -440,7 +449,8 @@ register_template(
     name="starchat",
     prefix=[
         {"token": "<|system|>"},
-        "\n"
+        "\n",
+        {"token": "<|end|>"}
     ],
     prompt=[
         {"token": "<|user|>"},
@@ -466,7 +476,8 @@ register_template(
     name="chatml",
     prefix=[
         {"token": "<|im_start|>"},
-        "system\nYou are a helpful assistant."
+        "system\nYou are a helpful assistant.",
+        {"token": "<|im_end|>"}
     ],
     prompt=[
         {"token": "<|im_start|>"},
@@ -484,3 +495,23 @@ register_template(
     ],
     use_history=True
 )
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+    name="chatglm2",
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"}
+    ],
+    prompt=[
+        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
+    ],
+    sep=[
+        "\n\n"
+    ],
+    stop_words=[],
+    use_history=True
+)
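To see what the new `{{idx}}` placeholder buys: the ChatGLM2 prompt numbers each round. A toy rendering with plain string replacement, mimicking what `_convert_inputs_to_ids` does per turn (illustration only):

```python
# The template code substitutes {{idx}} with str(turn_idx) for each turn.
prompt = "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
for turn_idx, query in enumerate(["Hello", "Goodbye"]):
    rendered = prompt.replace("{{idx}}", str(turn_idx), 1).replace("{{query}}", query, 1)
    print(rendered)
```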
@@ -24,7 +24,7 @@ class DatasetAttr:
 
 @dataclass
 class DataArguments:
-    """
+    r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
     template: str = field(
@@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field
 
 @dataclass
 class FinetuningArguments:
-    """
+    r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
     finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
@@ -14,7 +14,7 @@ class FinetuningArguments:
     )
     num_hidden_layers: Optional[int] = field(
         default=32,
-        metadata={"help": "Number of decoder blocks in the model. \
+        metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                   LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
@@ -25,16 +25,16 @@ class FinetuningArguments:
     )
     num_layer_trainable: Optional[int] = field(
         default=3,
-        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
+        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
     )
     name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
-                  LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
+        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                   Baichuan choices: [\"mlp\", \"self_attn\"], \
                   Qwen choices: [\"mlp\", \"attn\"], \
-                  InternLM, XVERSE choices: the same as LLaMA."}
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -51,11 +51,15 @@ class FinetuningArguments:
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
-                  InternLM, XVERSE choices: the same as LLaMA."}
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+    )
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
     )
 
     def __post_init__(self):
@@ -72,14 +76,14 @@ class FinetuningArguments:
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
-        """Saves the content of this instance in JSON format inside `json_path`."""
+        r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)
 
     @classmethod
     def load_from_json(cls, json_path: str):
-        """Creates an instance from the content of `json_path`."""
+        r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
@@ -4,10 +4,10 @@ from dataclasses import dataclass, field
 
 @dataclass
 class GeneralArguments:
-    """
+    r"""
     Arguments pertaining to which stage we are going to perform.
     """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."}
     )
@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
 
 @dataclass
 class GeneratingArguments:
-    """
+    r"""
     Arguments pertaining to specify the decoding parameters.
     """
     do_sample: Optional[bool] = field(
@@ -1,12 +1,11 @@
 import torch
 from typing import Literal, Optional
 from dataclasses import dataclass, field
-from huggingface_hub.hf_api import HfFolder
 
 
 @dataclass
 class ModelArguments:
-    """
+    r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
     """
     model_name_or_path: str = field(
@@ -64,12 +63,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
-    hf_hub_token : Optional[str] = field(
+    hf_auth_token: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+        metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
 
-
     def __post_init__(self):
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@@ -77,5 +75,6 @@ class ModelArguments:
         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 
-        if self.use_auth_token == True and self.hf_hub_token != None:
-            HfFolder.save_token(self.hf_hub_token)
+        if self.use_auth_token == True and self.hf_auth_token is not None:
+            from huggingface_hub.hf_api import HfFolder # lazy load
+            HfFolder.save_token(self.hf_auth_token)
@@ -39,7 +39,7 @@ def init_adapter(
     if finetuning_args.finetuning_type == "none" and is_trainable:
         raise ValueError("You cannot use finetuning_type=none while training.")
 
-    if finetuning_args.finetuning_type == "full":
+    if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         model = model.float()
 
@@ -34,7 +34,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
+require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
 
 
 def load_model_and_tokenizer(
@@ -52,9 +52,6 @@ def load_model_and_tokenizer(
         logger.warning("Checkpoint is not found at evaluation, load the original model.")
         finetuning_args = FinetuningArguments(finetuning_type="none")
 
-    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with the LoRA method."
-
     config_kwargs = {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -132,8 +129,6 @@ def load_model_and_tokenizer(
         })
 
     if stage == "ppo": # load reward model
-        assert is_trainable, "PPO stage cannot be performed at evaluation."
-        assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
         logger.info("Load reward model from {}".format(model_args.reward_model))
         model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
         assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
@@ -19,7 +19,7 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
-def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
     elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
@@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
 
 def parse_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    GeneralArguments
+]:
     parser = HfArgumentParser((
-        ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
+        ModelArguments,
+        DataArguments,
+        Seq2SeqTrainingArguments,
+        FinetuningArguments,
+        GeneratingArguments,
+        GeneralArguments
     ))
     return _parse_args(parser, args)
 
 
 def parse_infer_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    FinetuningArguments,
+    GeneratingArguments
+]:
     parser = HfArgumentParser((
-        ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+        ModelArguments,
+        DataArguments,
+        FinetuningArguments,
+        GeneratingArguments
     ))
     return _parse_args(parser, args)
 
 
 def get_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    GeneralArguments
+]:
+    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
 
     # Setup logging
     if training_args.should_log:
@@ -68,7 +95,7 @@ def get_train_args(
     data_args.init_for_training()
 
     if general_args.stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
@@ -76,6 +103,15 @@ def get_train_args(
     if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
+    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+        raise ValueError("RM and PPO training can only be performed with the LoRA method.")
+
+    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+        raise ValueError("PPO and DPO stage can only be performed at training.")
+
+    if general_args.stage == "ppo" and model_args.reward_model is None:
+        raise ValueError("Reward model is necessary for PPO training.")
+
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
 
@@ -133,12 +169,17 @@ def get_train_args(
     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)
 
-    return model_args, data_args, training_args, finetuning_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args, general_args
 
 
 def get_infer_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    FinetuningArguments,
+    GeneratingArguments
+]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
 
     if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
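Call sites of `get_train_args` now unpack six values instead of five; a sketch of the new shape (the import path is an assumption based on this diff's module names):

```python
# Hypothetical call site after this change: generating_args is the new value.
from llmtuner.tuner.core.parser import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args()
```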
@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
     from llmtuner.hparams import FinetuningArguments
 
 
 logger = get_logger(__name__)
 
 
-class PeftTrainer(Seq2SeqTrainer):
+class PeftModelMixin:
     r"""
-    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
     """
 
-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
-        super().__init__(**kwargs)
-        self.finetuning_args = finetuning_args
-        self._remove_log()
-
-    def _remove_log(self):
-        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
-            logger.warning("Previous log file in this folder will be deleted.")
-            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
+    def __init__(self) -> None: # for type checking
+        self.model: PreTrainedModel = None
+        self.tokenizer: "PreTrainedTokenizer" = None
+        self.args: "Seq2SeqTrainingArguments" = None
+        self.finetuning_args: "FinetuningArguments" = None
+        self.state: "TrainerState" = None
+        raise AssertionError("Mixin should not be initialized.")
 
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
|
||||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||||
else: # freeze/full-tuning
|
else: # freeze/full-tuning
|
||||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
|
||||||
|
r"""
|
||||||
|
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||||
|
Seq2SeqTrainer.__init__(self, **kwargs)
|
||||||
|
self.finetuning_args = finetuning_args
|
||||||
|
|
|
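The refactor above pulls the checkpoint-patching logic into a mixin so that both `PeftTrainer` and the new `DPOPeftTrainer` can reuse it. The pattern relies on Python's method resolution order: because the mixin is listed first among the bases, its `_save` shadows the one inherited from the Hugging Face `Trainer`. A minimal self-contained illustration of the idea (class names here are ours, not the repo's):

class Base:
    def _save(self):
        return "full checkpoint"

class Mixin:
    def _save(self):
        return "adapter-only checkpoint"

class Patched(Mixin, Base):
    pass

print(Patched.__mro__)    # (Patched, Mixin, Base, object)
print(Patched()._save())  # "adapter-only checkpoint" -- the mixin wins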
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.dpo.workflow import run_dpo
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/collator.py
@@ -0,0 +1,51 @@
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence, Tuple
+from transformers import DataCollatorForSeq2Seq
+
+
+@dataclass
+class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for pairwise data.
+    """
+
+    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
+        padded_labels = []
+        for feature, (prompt_len, answer_len) in zip(batch, positions):
+            if self.tokenizer.padding_side == "left":
+                start, end = feature.size(0) - answer_len, feature.size(0)
+            else:
+                start, end = prompt_len, prompt_len + answer_len
+            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
+            padded_tensor[start:end] = feature[start:end]
+            padded_labels.append(padded_tensor)
+        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Pads batched data to the longest sequence in the batch.
+
+        We generate 2 * n examples where the first n examples represent chosen examples and
+        the last n examples represent rejected examples.
+        """
+        concatenated_features = []
+        label_positions = []
+        for key in ("chosen_ids", "rejected_ids"):
+            for feature in features:
+                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
+                concatenated_features.append({
+                    "input_ids": feature["prompt_ids"] + feature[key],
+                    "attention_mask": [1] * (prompt_len + answer_len)
+                })
+                label_positions.append((prompt_len, answer_len))
+
+        batch = self.tokenizer.pad(
+            concatenated_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=self.return_tensors,
+        )
+        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
+        return batch
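For readers unfamiliar with the pairwise layout, here is a toy invocation of the collator above. The GPT-2 tokenizer and the token ids are stand-ins chosen only to make the example runnable:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

collator = DPODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
features = [
    {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7], "chosen_ids": [8, 9, 10], "rejected_ids": [11, 12]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([4, 5]): rows 0-1 are chosen, rows 2-3 rejected
print(batch["labels"][0])        # prompt positions are masked with -100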
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -0,0 +1,75 @@
+import torch
+from collections import defaultdict
+from peft import PeftModel
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from transformers import Trainer
+from trl import DPOTrainer
+
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.tuner.core.trainer import PeftModelMixin
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+
+
+class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+
+    def __init__(
+        self,
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+        **kwargs
+    ):
+        self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
+        self.ref_model = ref_model
+        self.use_dpo_data_collator = True # hack to avoid warning
+        self.label_pad_token_id = IGNORE_INDEX
+        self.padding_value = 0
+        self.beta = finetuning_args.dpo_beta
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+        Trainer.__init__(self, **kwargs)
+        if ref_model is not None:
+            if hasattr(self, "accelerator"):
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            else:
+                raise AttributeError("Please update `transformers`.")
+
+    def concatenated_forward(
+        self,
+        model: Optional[torch.nn.Module] = None,
+        batch: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        if not torch.is_grad_enabled():
+            unwrapped_model.gradient_checkpointing_disable()
+
+        if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
+            with unwrapped_model.disable_adapter():
+                all_logits: torch.Tensor = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    return_dict=True
+                ).logits.to(torch.float32)
+        else:
+            all_logits: torch.Tensor = model(
+                batch["input_ids"],
+                attention_mask=batch["attention_mask"],
+                return_dict=True
+            ).logits.to(torch.float32)
+
+        if not torch.is_grad_enabled():
+            unwrapped_model.gradient_checkpointing_enable()
+
+        all_logps = self._get_batch_logps(
+            all_logits,
+            batch["labels"],
+            average_log_prob=False
+        )
+        batch_size = batch["input_ids"].size(0) // 2
+        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
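`concatenated_forward` only produces the per-sequence log-probabilities; the preference loss itself is computed by `trl`'s `DPOTrainer`, which this class inherits. For reference, a sketch of that objective (following Rafailov et al., 2023; function and variable names here are ours):

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor, ref_rejected_logps: torch.Tensor,
             beta: float = 0.1):
    # loss = -log sigmoid(beta * ((log pi(y_w) - log pi(y_l)) - (log ref(y_w) - log ref(y_l))))
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    # implicit rewards, useful for logging reward margins during training
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards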
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -0,0 +1,59 @@
+# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+
+from copy import deepcopy
+from peft import PeftModel
+from typing import TYPE_CHECKING, Optional, List
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
+from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+
+def run_dpo(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    data_collator = DPODataCollatorWithPadding(
+        tokenizer=tokenizer,
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+    )
+
+    training_args.remove_unused_columns = False # important for pairwise dataset
+    ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
+
+    # Initialize our Trainer
+    trainer = DPOPeftTrainer(
+        finetuning_args=finetuning_args,
+        generating_args=generating_args,
+        ref_model=ref_model,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **split_dataset(dataset, data_args, training_args)
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        trainer.save_model()
+        if trainer.is_world_process_zero() and model_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
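Putting the new pieces together, a DPO run can be launched programmatically through `run_exp` (the dispatch in `tune.py` appears below). The argument values here are illustrative only; `dpo_beta` is the new hyperparameter read by `DPOPeftTrainer`:

from llmtuner.tuner.tune import run_exp

run_exp(args={
    "stage": "dpo",                            # routed to run_dpo
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # any supported base model
    "do_train": True,
    "dataset": "comparison_gpt4_en",           # a pairwise dataset with chosen/rejected answers
    "finetuning_type": "lora",
    "dpo_beta": 0.1,
    "output_dir": "path_to_dpo_checkpoint",
})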
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -10,7 +10,7 @@ from trl import PPOTrainer
 from trl.core import LengthSampler
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
+from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments
+    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
 
 
 logger = get_logger(__name__)
@@ -33,16 +33,17 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self,
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
         callbacks: List["LogCallback"],
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
         self.args = training_args
         self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
         self.log_callback = callbacks[0]
         self.state = TrainerState()
         self.control = TrainerControl()
-        self._remove_log()
 
     def ppo_train(self, max_target_length: int) -> None:
         r"""
@@ -72,14 +73,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
-        gen_kwargs = {
-            "top_k": 0.0,
-            "top_p": 1.0,
-            "do_sample": True,
-            "pad_token_id": self.tokenizer.pad_token_id,
-            "eos_token_id": self.tokenizer.eos_token_id,
-            "logits_processor": get_logits_processor()
-        }
+        gen_kwargs = self.generating_args.to_dict()
+        gen_kwargs["logits_processor"] = get_logits_processor()
+        gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
+
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
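`get_stopping_criteria` is a new helper in `llmtuner.extras.misc` whose implementation is not part of this diff. Judging from the call site, it plausibly wraps the tokenizer's additional special token ids in a `StoppingCriteriaList` along these lines (an assumption on our part, not the repo's actual code):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = set(stop_ids or [])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # stop as soon as the last generated token is one of the stop ids
        return input_ids[0, -1].item() in self.stop_ids

def get_stopping_criteria(stop_ids):
    return StoppingCriteriaList([StopOnTokens(stop_ids)])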
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -1,11 +1,9 @@
-# Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
+# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
-from typing import TYPE_CHECKING
 from trl import PPOConfig
 from torch.optim import AdamW
-from typing import Optional, List
+from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
 
@@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
 def run_ppo(
@@ -24,6 +22,7 @@ def run_ppo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
@@ -42,8 +41,9 @@ def run_ppo(
     )
 
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
-    total_train_batch_size = \
+    total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
     num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
@@ -56,6 +56,7 @@ def run_ppo(
     ppo_trainer = PPOPeftTrainer(
         training_args=training_args,
         finetuning_args=finetuning_args,
+        generating_args=generating_args,
         callbacks=callbacks,
         config=ppo_config,
         model=model,
@@ -67,8 +68,10 @@ def run_ppo(
         lr_scheduler=lr_scheduler
     )
 
-    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
-    ppo_trainer.save_model()
-    ppo_trainer.save_state() # must be after save_model
-    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args.output_dir, keys=["loss", "reward"])
+    # Training
+    if training_args.do_train:
+        ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+        ppo_trainer.save_model()
+        ppo_trainer.save_state() # must be called after save_model to have a folder
+        if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "reward"])
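The scheduler setup above derives the number of optimizer steps from the effective batch size across devices and accumulation steps. A quick worked example with made-up numbers:

import math

per_device, grad_accum, world_size = 4, 4, 2   # assumed launch configuration
dataset_len, num_train_epochs = 10_000, 3.0

total_train_batch_size = per_device * grad_accum * world_size            # 32
num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)
print(num_training_steps)  # 3.0 * ceil(10000 / 32) = 3.0 * 313 = 939.0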
--- a/src/llmtuner/tuner/pt/workflow.py
+++ b/src/llmtuner/tuner/pt/workflow.py
@@ -2,10 +2,9 @@
 
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.core.trainer import PeftTrainer
@@ -25,10 +24,7 @@ def run_pt(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DataCollatorForSeq2Seq(
-        tokenizer=tokenizer,
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
-    )
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
     trainer = PeftTrainer(
--- a/src/llmtuner/tuner/rm/collator.py
+++ b/src/llmtuner/tuner/rm/collator.py
@@ -1,8 +1,10 @@
 import torch
+from dataclasses import dataclass
 from typing import Any, Dict, Sequence
 from transformers import DataCollatorWithPadding
 
 
+@dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
         the last n examples represent rejected examples.
         """
         features = [
-            {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
-            for key in ("accept_ids", "reject_ids") for feature in features
+            {
+                "input_ids": feature["prompt_ids"] + feature[key],
+                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
+            }
+            for key in ("chosen_ids", "rejected_ids") for feature in features
         ]
         return super().__call__(features)
--- a/src/llmtuner/tuner/sft/trainer.py
+++ b/src/llmtuner/tuner/sft/trainer.py
@@ -79,7 +79,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
 
         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
-        return padded_tensor.contiguous()
+        return padded_tensor.contiguous() # in contiguous memory
 
     def save_predictions(
         self,
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
@@ -13,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
 def run_sft(
@@ -21,6 +21,7 @@ def run_sft(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
@@ -50,13 +51,9 @@ def run_sft(
     )
 
     # Keyword arguments for `model.generate`
-    gen_kwargs = {
-        "do_sample": True,
-        "top_p": 0.7,
-        "max_new_tokens": data_args.max_target_length + 1,
-        "temperature": 0.95,
-        "logits_processor": get_logits_processor()
-    }
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["logits_processor"] = get_logits_processor()
+    gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
 
     # Training
     if training_args.do_train:
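Both the SFT and PPO workflows now take their `model.generate` keyword arguments from `GeneratingArguments.to_dict()` instead of hard-coded literals. The dataclass itself lives in `llmtuner.hparams` and is not shown in this diff; a stand-in with the fields the removed literals suggest might look like this (field names and defaults are assumptions):

from dataclasses import asdict, dataclass
from llmtuner.extras.misc import get_logits_processor

@dataclass
class GeneratingArguments:
    # plausible fields, inferred from the literals removed above
    do_sample: bool = True
    temperature: float = 0.95
    top_p: float = 0.7
    max_new_tokens: int = 512

    def to_dict(self):
        return asdict(self)

gen_kwargs = GeneratingArguments().to_dict()
gen_kwargs["logits_processor"] = get_logits_processor()  # as in the workflows above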
--- a/src/llmtuner/tuner/tune.py
+++ b/src/llmtuner/tuner/tune.py
@@ -1,35 +1,47 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
 from llmtuner.tuner.pt import run_pt
 from llmtuner.tuner.sft import run_sft
 from llmtuner.tuner.rm import run_rm
 from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.dpo import run_dpo
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
 
 
+logger = get_logger(__name__)
+
+
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
-    model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args)
-    callbacks = [LogCallback()] if callbacks is None else callbacks
+    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks + [LogCallback()]
 
     if general_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif general_args.stage == "sft":
-        run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
+        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
     elif general_args.stage == "rm":
         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
     elif general_args.stage == "ppo":
-        run_ppo(model_args, data_args, training_args, finetuning_args, callbacks)
+        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    elif general_args.stage == "dpo":
+        run_dpo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    else:
+        raise ValueError("Unknown task.")
 
 
 def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
-    model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
     model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
-    tokenizer.save_pretrained(training_args.output_dir)
+    try:
+        tokenizer.save_pretrained(training_args.output_dir)
+    except Exception:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
 
 
 if __name__ == "__main__":
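`export_model` (unchanged here apart from the guarded tokenizer save) writes the tuned weights out as a plain Hugging Face checkpoint. An illustrative call, assuming the usual `checkpoint_dir` argument for pointing at a finished LoRA run:

from llmtuner.tuner.tune import export_model

export_model(args={
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "checkpoint_dir": "path_to_dpo_checkpoint",  # assumption: the standard LLaMA-Factory flag
    "output_dir": "path_to_exported_model",
}, max_shard_size="10GB")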
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -4,7 +4,7 @@ import threading
 import time
 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME
-from typing import Generator, List, Optional, Tuple
+from typing import Generator, List, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE