support DPO training (2305.18290)

This commit is contained in:
hiyouga 2023-08-11 03:02:53 +08:00
parent 685dae4eff
commit 3ec4351cfd
34 changed files with 513 additions and 1027300 deletions

View File

@ -12,6 +12,8 @@
## Changelog ## Changelog
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model. [23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset. [23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
@ -54,24 +56,18 @@
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options. - **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. - For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
## Supported Training Approaches ## Supported Training Approaches
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
- Full-parameter tuning | ---------------------- | -------------- | ----------------- | ---- | ----- |
- Partial-parameter tuning | Pre-Training | ✅ | ✅ | ✅ | ✅ |
- [LoRA](https://arxiv.org/abs/2106.09685) | Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
- [QLoRA](https://arxiv.org/abs/2305.14314) | Reward Model Training | | | ✅ | ✅ |
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652) | PPO Training | | | ✅ | ✅ |
- Full-parameter tuning | DPO Training | ✅ | | ✅ | ✅ |
- Partial-parameter tuning
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
## Provided Datasets ## Provided Datasets
@ -88,7 +84,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json) - [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@ -103,7 +98,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling: - For reward modelling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@ -139,7 +134,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
### Dependence Installation (optional) ### Dependence Installation (optional)
```bash ```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10 conda create -n llama_etuning python=3.10
conda activate llama_etuning conda activate llama_etuning
@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
Currently the web UI only supports training on **a single GPU**. Currently the web UI only supports training on **a single GPU**.
### (Continually) Pre-Training ### Pre-Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \ --resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \ --checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \ --output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \ --gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \ --lr_scheduler_type cosine \
--logging_steps 10 \ --logging_steps 10 \
@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
### PPO Training (RLHF) ### PPO Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss --plot_loss
``` ```
### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Distributed Training ### Distributed Training
```bash ```bash

View File

@ -12,7 +12,9 @@
## 更新日志 ## 更新日志
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat``--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数。 [23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-training)(实验性功能)。
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat``--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型请添加 `--template chatml` 参数。
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming``--max_steps 100` 参数来流式加载数据集。 [23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming``--max_steps 100` 参数来流式加载数据集。
@ -54,41 +56,34 @@
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
> * **默认模块**是 `--lora_target` 参数的默认值。请使用 `python src/train_bash.py -h` 查看全部可选项。 - **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
> * 对于所有“基座”模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等值。 - 对于所有“基座”Base模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna`任意值。但“对话”Chat模型请务必使用对应的模板。
## 微调方法 ## 训练方法
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) | 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
- 全参数微调 | ---------- | ---------- | ----------- | ---- | ----- |
- 部分参数微调 | 预训练 | ✅ | ✅ | ✅ | ✅ |
- [LoRA](https://arxiv.org/abs/2106.09685) | 指令监督微调 | ✅ | ✅ | ✅ | ✅ |
- [QLoRA](https://arxiv.org/abs/2305.14314) | 奖励模型训练 | | | ✅ | ✅ |
- [指令监督微调](https://arxiv.org/abs/2109.01652) | PPO 训练 | | | ✅ | ✅ |
- 全参数微调 | DPO 训练 | ✅ | | ✅ | ✅ |
- 部分参数微调
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [人类反馈的强化学习RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
## 数据集 ## 数据集
- 用于二次预训练: - 用于预训练:
- [Wiki Demo (en)](data/wiki_demo.txt) - [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- 用于指令监督微调: - 用于指令监督微调
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json) - [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@ -103,7 +98,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- 用于奖励模型训练: - 用于奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@ -139,7 +134,6 @@ huggingface-cli login
### 环境搭建(可跳过) ### 环境搭建(可跳过)
```bash ```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10 conda create -n llama_etuning python=3.10
conda activate llama_etuning conda activate llama_etuning
@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
目前网页 UI 仅支持**单卡训练**。 目前网页 UI 仅支持**单卡训练**。
### 二次预训练 ### 预训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \ --resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \ --checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \ --output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \ --gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \ --lr_scheduler_type cosine \
--logging_steps 10 \ --logging_steps 10 \
@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16 --fp16
``` ```
### RLHF 训练 ### PPO 训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss --plot_loss
``` ```
### DPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### 多 GPU 分布式训练 ### 多 GPU 分布式训练
```bash ```bash

View File

@ -49,26 +49,6 @@
"history": "history" "history": "history"
} }
}, },
"refgpt_zh_p1": {
"file_name": "refgpt_zh_50k_p1.json",
"file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"refgpt_zh_p2": {
"file_name": "refgpt_zh_50k_p2.json",
"file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"lima": { "lima": {
"file_name": "lima.json", "file_name": "lima.json",
"file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37", "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@ -3,7 +3,7 @@ transformers>=4.29.1
datasets>=2.12.0 datasets>=2.12.0
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.4.0 peft>=0.4.0
trl>=0.4.7 trl>=0.5.0
scipy scipy
sentencepiece sentencepiece
tiktoken tiktoken
@ -16,4 +16,3 @@ pydantic==1.10.11
fastapi==0.95.1 fastapi==0.95.1
sse-starlette sse-starlette
matplotlib matplotlib
huggingface_hub

View File

@ -7,7 +7,7 @@ def main():
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
# Visit http://localhost:8000/docs for document. print("Visit http://localhost:8000/docs for API document.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -18,7 +18,6 @@ class ChatModel:
self.model = self.model.eval() # change to eval mode self.model = self.model.eval() # change to eval mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix self.source_prefix = data_args.source_prefix
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen) self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args( def process_args(
@ -53,7 +52,7 @@ class ChatModel:
top_k=top_k or gen_kwargs["top_k"], top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
stopping_criteria=get_stopping_criteria(self.stop_ids) stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
)) ))
if max_length: if max_length:

View File

@ -46,7 +46,6 @@ def preprocess_dataset(
k: [t[i: i + block_size] for i in range(0, total_length, block_size)] k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items() for k, t in concatenated_examples.items()
} }
result["labels"] = result["input_ids"].copy()
return result return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
@ -95,24 +94,22 @@ def preprocess_dataset(
return model_inputs return model_inputs
def preprocess_pairwise_dataset(examples): def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []} model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, prefix in construct_example(examples): for query, response, history, prefix in construct_example(examples):
source_ids, accept_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix) _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
if len(source_ids) > data_args.max_source_length: if len(prompt_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length] prompt_ids = prompt_ids[:data_args.max_source_length]
if len(accept_ids) > data_args.max_target_length: if len(chosen_ids) > data_args.max_target_length:
accept_ids = accept_ids[:data_args.max_target_length] chosen_ids = chosen_ids[:data_args.max_target_length]
if len(reject_ids) > data_args.max_target_length: if len(rejected_ids) > data_args.max_target_length:
reject_ids = reject_ids[:data_args.max_target_length] rejected_ids = rejected_ids[:data_args.max_target_length]
accept_ids = source_ids + accept_ids model_inputs["prompt_ids"].append(prompt_ids)
reject_ids = source_ids + reject_ids model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
return model_inputs return model_inputs
def print_supervised_dataset_example(example): def print_supervised_dataset_example(example):
@ -124,10 +121,12 @@ def preprocess_dataset(
], skip_special_tokens=False))) ], skip_special_tokens=False)))
def print_pairwise_dataset_example(example): def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"])) print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False))) print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("reject_ids:\n{}".format(example["reject_ids"])) print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False))) print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example): def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))

View File

@ -7,10 +7,16 @@ from datetime import timedelta
from transformers import TrainerCallback from transformers import TrainerCallback
from transformers.trainer_utils import has_length from transformers.trainer_utils import has_length
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl from transformers import TrainingArguments, TrainerState, TrainerControl
logger = get_logger(__name__)
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
def __init__(self, runner=None): def __init__(self, runner=None):
@ -38,6 +44,9 @@ class LogCallback(TrainerCallback):
self.in_training = True self.in_training = True
self.start_time = time.time() self.start_time = time.time()
self.max_steps = state.max_steps self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""

View File

@ -1,10 +1,12 @@
IGNORE_INDEX = -100 IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
VALUE_HEAD_FILE_NAME = "value_head.bin" VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json" FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]

View File

@ -1,7 +1,11 @@
import torch import torch
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import (
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList LogitsProcessor,
LogitsProcessorList,
StoppingCriteria,
StoppingCriteriaList
)
from llmtuner.extras.constants import LAYERNORM_NAMES from llmtuner.extras.constants import LAYERNORM_NAMES

View File

@ -61,7 +61,7 @@ class Template:
prefix: Optional[str] = None prefix: Optional[str] = None
) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]: ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
r""" r"""
Aligns inputs to a special format. Aligns inputs to the standard format.
""" """
prefix = [prefix] if prefix else self.prefix # use prefix if provided prefix = [prefix] if prefix else self.prefix # use prefix if provided
history = history if (history and self.use_history) else [] history = history if (history and self.use_history) else []
@ -92,28 +92,32 @@ class Template:
) -> List[Tuple[List[int], List[int]]]: ) -> List[Tuple[List[int], List[int]]]:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
""" """
bos_ids, eos_ids = self._get_special_ids(tokenizer) bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = [] encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history): for turn_idx, (query, resp) in enumerate(history):
if turn_idx != 0: if turn_idx == 0:
prefix_ids = sep_ids if prefix: # has prefix
elif prefix: prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
else: else:
prefix_ids = [] prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids)) encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs return encoded_pairs
def _convert_inputs_to_ids( def _convert_inputs_to_ids(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]], context: List[Union[str, Dict[str, str]]],
query: Optional[str] = "" query: Optional[str] = "",
idx: Optional[str] = ""
) -> List[int]: ) -> List[int]:
r""" r"""
Converts context to token ids. Converts context to token ids.
@ -127,6 +131,7 @@ class Template:
for elem in context: for elem in context:
if isinstance(elem, str): if isinstance(elem, str):
elem = elem.replace("{{query}}", query, 1) elem = elem.replace("{{query}}", query, 1)
elem = elem.replace("{{idx}}", idx, 1)
token_ids = token_ids + tokenizer.encode(elem, **kwargs) token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict): elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
@ -146,10 +151,12 @@ class Llama2Template(Template):
) -> List[Tuple[List[int], List[int]]]: ) -> List[Tuple[List[int], List[int]]]:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
""" """
bos_ids, eos_ids = self._get_special_ids(tokenizer) bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = [] encoded_pairs = []
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
for turn_idx, (query, resp) in enumerate(history): for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has not sep_ids if turn_idx == 0: # llama2 template has not sep_ids
query = prefix[0] + query query = prefix[0] + query
@ -187,10 +194,11 @@ def get_template_and_fix_tokenizer(
template = templates.get(name, None) template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name) assert template is not None, "Template {} does not exist.".format(name)
if tokenizer.eos_token_id is None: # inplace method if len(template.stop_words): # inplace method
if len(template.stop_words):
tokenizer.eos_token = template.stop_words[0] tokenizer.eos_token = template.stop_words[0]
else: logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>" tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token)) logger.info("Add eos token: {}".format(tokenizer.eos_token))
@ -422,12 +430,13 @@ register_template(
name="baichuan", name="baichuan",
prefix=[], prefix=[],
prompt=[ prompt=[
{"token": "<reserved_102>"}, {"token": "<reserved_102>"}, # user token (a little difference in position)
"{{query}}", "{{query}}"
{"token": "<reserved_103>"}
], ],
sep=[], sep=[],
stop_words=[], stop_words=[
"<reserved_103>" # assistant token
],
use_history=True use_history=True
) )
@ -440,7 +449,8 @@ register_template(
name="starchat", name="starchat",
prefix=[ prefix=[
{"token": "<|system|>"}, {"token": "<|system|>"},
"\n" "\n",
{"token": "<|end|>"}
], ],
prompt=[ prompt=[
{"token": "<|user|>"}, {"token": "<|user|>"},
@ -466,7 +476,8 @@ register_template(
name="chatml", name="chatml",
prefix=[ prefix=[
{"token": "<|im_start|>"}, {"token": "<|im_start|>"},
"system\nYou are a helpful assistant." "system\nYou are a helpful assistant.",
{"token": "<|im_end|>"}
], ],
prompt=[ prompt=[
{"token": "<|im_start|>"}, {"token": "<|im_start|>"},
@ -484,3 +495,23 @@ register_template(
], ],
use_history=True use_history=True
) )
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"}
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
sep=[
"\n\n"
],
stop_words=[],
use_history=True
)

View File

@ -24,7 +24,7 @@ class DatasetAttr:
@dataclass @dataclass
class DataArguments: class DataArguments:
""" r"""
Arguments pertaining to what data we are going to input our model for training and evaluation. Arguments pertaining to what data we are going to input our model for training and evaluation.
""" """
template: str = field( template: str = field(

View File

@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field
@dataclass @dataclass
class FinetuningArguments: class FinetuningArguments:
""" r"""
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field( finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
@ -14,7 +14,7 @@ class FinetuningArguments:
) )
num_hidden_layers: Optional[int] = field( num_hidden_layers: Optional[int] = field(
default=32, default=32,
metadata={"help": "Number of decoder blocks in the model. \ metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \ LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \ LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \ BLOOM choices: [\"24\", \"30\", \"70\"], \
@ -25,16 +25,16 @@ class FinetuningArguments:
) )
num_layer_trainable: Optional[int] = field( num_layer_trainable: Optional[int] = field(
default=3, default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."} metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
) )
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
default="mlp", default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \ metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \ LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \ BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"], \ Baichuan choices: [\"mlp\", \"self_attn\"], \
Qwen choices: [\"mlp\", \"attn\"], \ Qwen choices: [\"mlp\", \"attn\"], \
InternLM, XVERSE choices: the same as LLaMA."} LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
) )
lora_rank: Optional[int] = field( lora_rank: Optional[int] = field(
default=8, default=8,
@ -51,11 +51,15 @@ class FinetuningArguments:
lora_target: Optional[str] = field( lora_target: Optional[str] = field(
default="q_proj,v_proj", default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
InternLM, XVERSE choices: the same as LLaMA."} LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
) )
def __post_init__(self): def __post_init__(self):
@ -72,14 +76,14 @@ class FinetuningArguments:
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
"""Saves the content of this instance in JSON format inside `json_path`.""" r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f: with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string) f.write(json_string)
@classmethod @classmethod
def load_from_json(cls, json_path: str): def load_from_json(cls, json_path: str):
"""Creates an instance from the content of `json_path`.""" r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f: with open(json_path, "r", encoding="utf-8") as f:
text = f.read() text = f.read()
return cls(**json.loads(text)) return cls(**json.loads(text))

View File

@ -4,10 +4,10 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class GeneralArguments: class GeneralArguments:
""" r"""
Arguments pertaining to which stage we are going to perform. Arguments pertaining to which stage we are going to perform.
""" """
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field( stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft", default="sft",
metadata={"help": "Which stage will be performed in training."} metadata={"help": "Which stage will be performed in training."}
) )

View File

@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
@dataclass @dataclass
class GeneratingArguments: class GeneratingArguments:
""" r"""
Arguments pertaining to specify the decoding parameters. Arguments pertaining to specify the decoding parameters.
""" """
do_sample: Optional[bool] = field( do_sample: Optional[bool] = field(

View File

@ -1,12 +1,11 @@
import torch import torch
from typing import Literal, Optional from typing import Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from huggingface_hub.hf_api import HfFolder
@dataclass @dataclass
class ModelArguments: class ModelArguments:
""" r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune. Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
""" """
model_name_or_path: str = field( model_name_or_path: str = field(
@ -64,12 +63,11 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."} metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
) )
hf_hub_token : Optional[str] = field( hf_auth_token: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."} metadata={"help": "Auth token to log in with Hugging Face Hub."}
) )
def __post_init__(self): def __post_init__(self):
if self.checkpoint_dir is not None: # support merging multiple lora weights if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@ -77,5 +75,6 @@ class ModelArguments:
if self.quantization_bit is not None: if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
if self.use_auth_token == True and self.hf_hub_token != None: if self.use_auth_token == True and self.hf_auth_token is not None:
HfFolder.save_token(self.hf_hub_token) from huggingface_hub.hf_api import HfFolder # lazy load
HfFolder.save_token(self.hf_auth_token)

View File

@ -39,7 +39,7 @@ def init_adapter(
if finetuning_args.finetuning_type == "none" and is_trainable: if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.") raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full": if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
model = model.float() model = model.float()

View File

@ -34,7 +34,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7") require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
def load_model_and_tokenizer( def load_model_and_tokenizer(
@ -52,9 +52,6 @@ def load_model_and_tokenizer(
logger.warning("Checkpoint is not found at evaluation, load the original model.") logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none") finetuning_args = FinetuningArguments(finetuning_type="none")
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
"RM and PPO training can only be performed with the LoRA method."
config_kwargs = { config_kwargs = {
"trust_remote_code": True, "trust_remote_code": True,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir,
@ -132,8 +129,6 @@ def load_model_and_tokenizer(
}) })
if stage == "ppo": # load reward model if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model)) logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

View File

@ -19,7 +19,7 @@ from llmtuner.hparams import (
logger = get_logger(__name__) logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None): def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None: if args is not None:
return parser.parse_dict(args) return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
def parse_train_args( def parse_train_args(
args: Optional[Dict[str, Any]] = None args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: ) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
parser = HfArgumentParser(( parser = HfArgumentParser((
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
)) ))
return _parse_args(parser, args) return _parse_args(parser, args)
def parse_infer_args( def parse_infer_args(
args: Optional[Dict[str, Any]] = None args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: ) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser(( parser = HfArgumentParser((
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
)) ))
return _parse_args(parser, args) return _parse_args(parser, args)
def get_train_args( def get_train_args(
args: Optional[Dict[str, Any]] = None args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: ) -> Tuple[
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args) ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
GeneralArguments
]:
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
# Setup logging # Setup logging
if training_args.should_log: if training_args.should_log:
@ -68,7 +95,7 @@ def get_train_args(
data_args.init_for_training() data_args.init_for_training()
if general_args.stage != "sft" and training_args.predict_with_generate: if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.") raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if training_args.do_train and training_args.predict_with_generate: if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.") raise ValueError("`predict_with_generate` cannot be set as True while training.")
@ -76,6 +103,15 @@ def get_train_args(
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.") raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stage can only be performed at training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if training_args.max_steps == -1 and data_args.streaming: if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.") raise ValueError("Please specify `max_steps` in streaming mode.")
@ -133,12 +169,17 @@ def get_train_args(
# Set seed before initializing model. # Set seed before initializing model.
transformers.set_seed(training_args.seed) transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, general_args return model_args, data_args, training_args, finetuning_args, generating_args, general_args
def get_infer_args( def get_infer_args(
args: Optional[Dict[str, Any]] = None args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: ) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":

View File

@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer): class PeftModelMixin:
r""" r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
""" """
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): def __init__(self) -> None: # for type checking
super().__init__(**kwargs) self.model: PreTrainedModel = None
self.finetuning_args = finetuning_args self.tokenizer: "PreTrainedTokenizer" = None
self._remove_log() self.args: "Seq2SeqTrainingArguments" = None
self.finetuning_args: "FinetuningArguments" = None
def _remove_log(self): self.state: "TrainerState" = None
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")): raise AssertionError("Mixin should not be initialized.")
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r""" r"""
@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint) load_trainable_params(model, self.state.best_model_checkpoint)
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
Seq2SeqTrainer.__init__(self, **kwargs)
self.finetuning_args = finetuning_args

View File

@ -0,0 +1 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@ -0,0 +1,51 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq
@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
start, end = feature.size(0) - answer_len, feature.size(0)
else:
start, end = prompt_len, answer_len
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
label_positions = []
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append({
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len)
})
label_positions.append((prompt_len, answer_len))
batch = self.tokenizer.pad(
concatenated_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch

View File

@ -0,0 +1,75 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import Trainer
from trl import DPOTrainer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
**kwargs
):
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = finetuning_args.dpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
if ref_model is not None:
if hasattr(self, "accelerator"):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
else:
raise AttributeError("Please update `transformers`.")
def concatenated_forward(
self,
model: Optional[torch.nn.Module] = None,
batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_disable()
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
with unwrapped_model.disable_adapter():
all_logits: torch.Tensor = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
return_dict=True
).logits.to(torch.float32)
else:
all_logits: torch.Tensor = model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
return_dict=True
).logits.to(torch.float32)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_enable()
all_logps = self._get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits

View File

@ -0,0 +1,59 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
training_args.remove_unused_columns = False # important for pairwise dataset
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
generating_args=generating_args,
ref_model=ref_model,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args)
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

View File

@ -10,7 +10,7 @@ from trl import PPOTrainer
from trl.core import LengthSampler from trl.core import LengthSampler
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
@ -18,7 +18,7 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments from llmtuner.hparams import FinetuningArguments, GeneratingArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@ -33,16 +33,17 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self, self,
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["LogCallback"], callbacks: List["LogCallback"],
**kwargs **kwargs
): ):
PPOTrainer.__init__(self, **kwargs) PPOTrainer.__init__(self, **kwargs)
self.args = training_args self.args = training_args
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.log_callback = callbacks[0] self.log_callback = callbacks[0]
self.state = TrainerState() self.state = TrainerState()
self.control = TrainerControl() self.control = TrainerControl()
self._remove_log()
def ppo_train(self, max_target_length: int) -> None: def ppo_train(self, max_target_length: int) -> None:
r""" r"""
@ -72,14 +73,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}") logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = { gen_kwargs = self.generating_args.to_dict()
"top_k": 0.0, gen_kwargs["logits_processor"] = get_logits_processor()
"top_p": 1.0, gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
"do_sample": True,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
length_sampler = LengthSampler(max_target_length // 2, max_target_length) length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

View File

@ -1,11 +1,9 @@
# Inspired by: # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math import math
from typing import TYPE_CHECKING
from trl import PPOConfig from trl import PPOConfig
from torch.optim import AdamW from torch.optim import AdamW
from typing import Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo( def run_ppo(
@ -24,6 +22,7 @@ def run_ppo(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
@ -42,8 +41,9 @@ def run_ppo(
) )
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = \ total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
training_args.lr_scheduler_type, training_args.lr_scheduler_type,
@ -56,6 +56,7 @@ def run_ppo(
ppo_trainer = PPOPeftTrainer( ppo_trainer = PPOPeftTrainer(
training_args=training_args, training_args=training_args,
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks, callbacks=callbacks,
config=ppo_config, config=ppo_config,
model=model, model=model,
@ -67,8 +68,10 @@ def run_ppo(
lr_scheduler=lr_scheduler lr_scheduler=lr_scheduler
) )
# Training
if training_args.do_train:
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model() ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss: if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"]) plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@ -2,10 +2,9 @@
import math import math
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
@ -25,10 +24,7 @@ def run_pt(
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DataCollatorForSeq2Seq( data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
tokenizer=tokenizer,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Initialize our Trainer # Initialize our Trainer
trainer = PeftTrainer( trainer = PeftTrainer(

View File

@ -1,8 +1,10 @@
import torch import torch
from dataclasses import dataclass
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r""" r"""
Data collator for pairwise data. Data collator for pairwise data.
@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
the last n examples represent rejected examples. the last n examples represent rejected examples.
""" """
features = [ features = [
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])} {
for key in ("accept_ids", "reject_ids") for feature in features "input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
}
for key in ("chosen_ids", "rejected_ids") for feature in features
] ]
return super().__call__(features) return super().__call__(features)

View File

@ -79,7 +79,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor.contiguous() return padded_tensor.contiguous() # in contiguous memory
def save_predictions( def save_predictions(
self, self,

View File

@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics from llmtuner.tuner.sft.metric import ComputeMetrics
@ -13,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_sft( def run_sft(
@ -21,6 +21,7 @@ def run_sft(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
@ -50,13 +51,9 @@ def run_sft(
) )
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = { gen_kwargs = generating_args.to_dict()
"do_sample": True, gen_kwargs["logits_processor"] = get_logits_processor()
"top_p": 0.7, gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
"max_new_tokens": data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
# Training # Training
if training_args.do_train: if training_args.do_train:

View File

@ -1,35 +1,47 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks callbacks = [LogCallback()] if callbacks is None else callbacks + [LogCallback()]
if general_args.stage == "pt": if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks) run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "sft": elif general_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, callbacks) run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "rm": elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks) run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "ppo": elif general_args.stage == "ppo":
run_ppo(model_args, data_args, training_args, finetuning_args, callbacks) run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
model_args, _, training_args, finetuning_args, _ = get_train_args(args) model_args, _, training_args, finetuning_args, _ = get_train_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
try:
tokenizer.save_pretrained(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -4,7 +4,7 @@ import threading
import time import time
import transformers import transformers
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from typing import Generator, List, Optional, Tuple from typing import Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE from llmtuner.extras.constants import DEFAULT_MODULE