support DPO training (2305.18290)
This commit is contained in:
parent
685dae4eff
commit
3ec4351cfd
60
README.md
60
README.md
|
@ -12,6 +12,8 @@
|
|||
|
||||
## Changelog
|
||||
|
||||
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
|
||||
|
||||
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
|
||||
|
||||
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
|
||||
|
@ -54,24 +56,18 @@
|
|||
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
|
||||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
|
||||
|
||||
> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
|
||||
> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc.
|
||||
- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
|
||||
- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
|
||||
|
||||
## Supported Training Approaches
|
||||
|
||||
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
- Full-parameter tuning
|
||||
- Partial-parameter tuning
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
|
||||
- Full-parameter tuning
|
||||
- Partial-parameter tuning
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [RLHF](https://arxiv.org/abs/2203.02155)
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
|
||||
| ---------------------- | -------------- | ----------------- | ---- | ----- |
|
||||
| Pre-Training | ✅ | ✅ | ✅ | ✅ |
|
||||
| Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
|
||||
| Reward Model Training | | | ✅ | ✅ |
|
||||
| PPO Training | | | ✅ | ✅ |
|
||||
| DPO Training | ✅ | | ✅ | ✅ |
|
||||
|
||||
## Provided Datasets
|
||||
|
||||
|
@ -88,7 +84,6 @@
|
|||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
|
@ -103,7 +98,7 @@
|
|||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- For reward modelling:
|
||||
- For reward modelling or DPO training:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
@ -139,7 +134,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
|
|||
### Dependence Installation (optional)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
|
@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
|
|||
|
||||
Currently the web UI only supports training on **a single GPU**.
|
||||
|
||||
### (Continually) Pre-Training
|
||||
### Pre-Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
|
@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--fp16
|
||||
```
|
||||
|
||||
### PPO Training (RLHF)
|
||||
### PPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--plot_loss
|
||||
```
|
||||
|
||||
### DPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### Distributed Training
|
||||
|
||||
```bash
|
||||
|
|
68
README_zh.md
68
README_zh.md
|
@ -12,7 +12,9 @@
|
|||
|
||||
## 更新日志
|
||||
|
||||
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数。
|
||||
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-training)(实验性功能)。
|
||||
|
||||
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型请添加 `--template chatml` 参数。
|
||||
|
||||
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
|
||||
|
||||
|
@ -54,41 +56,34 @@
|
|||
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
|
||||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
|
||||
|
||||
> * **默认模块**是 `--lora_target` 参数的默认值。请使用 `python src/train_bash.py -h` 查看全部可选项。
|
||||
> * 对于所有“基座”模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等值。
|
||||
- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
|
||||
- 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。
|
||||
|
||||
## 微调方法
|
||||
## 训练方法
|
||||
|
||||
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
- 全参数微调
|
||||
- 部分参数微调
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [指令监督微调](https://arxiv.org/abs/2109.01652)
|
||||
- 全参数微调
|
||||
- 部分参数微调
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155)
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
||||
| ---------- | ---------- | ----------- | ---- | ----- |
|
||||
| 预训练 | ✅ | ✅ | ✅ | ✅ |
|
||||
| 指令监督微调 | ✅ | ✅ | ✅ | ✅ |
|
||||
| 奖励模型训练 | | | ✅ | ✅ |
|
||||
| PPO 训练 | | | ✅ | ✅ |
|
||||
| DPO 训练 | ✅ | | ✅ | ✅ |
|
||||
|
||||
## 数据集
|
||||
|
||||
- 用于二次预训练:
|
||||
- 用于预训练:
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- 用于指令监督微调:
|
||||
- 用于指令监督微调:
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
|
@ -103,7 +98,7 @@
|
|||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- 用于奖励模型训练:
|
||||
- 用于奖励模型或 DPO 训练:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
@ -139,7 +134,6 @@ huggingface-cli login
|
|||
### 环境搭建(可跳过)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
|
@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
|
|||
|
||||
目前网页 UI 仅支持**单卡训练**。
|
||||
|
||||
### 二次预训练
|
||||
### 预训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
|
@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--fp16
|
||||
```
|
||||
|
||||
### RLHF 训练
|
||||
### PPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--plot_loss
|
||||
```
|
||||
|
||||
### DPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### 多 GPU 分布式训练
|
||||
|
||||
```bash
|
||||
|
|
|
@ -49,26 +49,6 @@
|
|||
"history": "history"
|
||||
}
|
||||
},
|
||||
"refgpt_zh_p1": {
|
||||
"file_name": "refgpt_zh_50k_p1.json",
|
||||
"file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
},
|
||||
"refgpt_zh_p2": {
|
||||
"file_name": "refgpt_zh_50k_p2.json",
|
||||
"file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
},
|
||||
"lima": {
|
||||
"file_name": "lima.json",
|
||||
"file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",
|
||||
|
|
520932
data/refgpt_zh_50k_p1.json
520932
data/refgpt_zh_50k_p1.json
File diff suppressed because it is too large
Load Diff
506158
data/refgpt_zh_50k_p2.json
506158
data/refgpt_zh_50k_p2.json
File diff suppressed because one or more lines are too long
|
@ -3,7 +3,7 @@ transformers>=4.29.1
|
|||
datasets>=2.12.0
|
||||
accelerate>=0.21.0
|
||||
peft>=0.4.0
|
||||
trl>=0.4.7
|
||||
trl>=0.5.0
|
||||
scipy
|
||||
sentencepiece
|
||||
tiktoken
|
||||
|
@ -16,4 +16,3 @@ pydantic==1.10.11
|
|||
fastapi==0.95.1
|
||||
sse-starlette
|
||||
matplotlib
|
||||
huggingface_hub
|
|
@ -7,7 +7,7 @@ def main():
|
|||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
# Visit http://localhost:8000/docs for document.
|
||||
print("Visit http://localhost:8000/docs for API document.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -18,7 +18,6 @@ class ChatModel:
|
|||
self.model = self.model.eval() # change to eval mode
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||
|
||||
def process_args(
|
||||
|
@ -53,7 +52,7 @@ class ChatModel:
|
|||
top_k=top_k or gen_kwargs["top_k"],
|
||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||
logits_processor=get_logits_processor(),
|
||||
stopping_criteria=get_stopping_criteria(self.stop_ids)
|
||||
stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
||||
))
|
||||
|
||||
if max_length:
|
||||
|
|
|
@ -46,7 +46,6 @@ def preprocess_dataset(
|
|||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
|
@ -95,24 +94,22 @@ def preprocess_dataset(
|
|||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
source_ids, accept_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
|
||||
source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(accept_ids) > data_args.max_target_length:
|
||||
accept_ids = accept_ids[:data_args.max_target_length]
|
||||
if len(reject_ids) > data_args.max_target_length:
|
||||
reject_ids = reject_ids[:data_args.max_target_length]
|
||||
if len(prompt_ids) > data_args.max_source_length:
|
||||
prompt_ids = prompt_ids[:data_args.max_source_length]
|
||||
if len(chosen_ids) > data_args.max_target_length:
|
||||
chosen_ids = chosen_ids[:data_args.max_target_length]
|
||||
if len(rejected_ids) > data_args.max_target_length:
|
||||
rejected_ids = rejected_ids[:data_args.max_target_length]
|
||||
|
||||
accept_ids = source_ids + accept_ids
|
||||
reject_ids = source_ids + reject_ids
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
|
@ -124,10 +121,12 @@ def preprocess_dataset(
|
|||
], skip_special_tokens=False)))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
|
|
|
@ -7,10 +7,16 @@ from datetime import timedelta
|
|||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import has_length
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
|
@ -38,6 +44,9 @@ class LogCallback(TrainerCallback):
|
|||
self.in_training = True
|
||||
self.start_time = time.time()
|
||||
self.max_steps = state.max_steps
|
||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
IGNORE_INDEX = -100
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
|
||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
import torch
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
|
||||
from transformers import (
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList
|
||||
)
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ class Template:
|
|||
prefix: Optional[str] = None
|
||||
) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
|
||||
r"""
|
||||
Aligns inputs to a special format.
|
||||
Aligns inputs to the standard format.
|
||||
"""
|
||||
prefix = [prefix] if prefix else self.prefix # use prefix if provided
|
||||
history = history if (history and self.use_history) else []
|
||||
|
@ -92,28 +92,32 @@ class Template:
|
|||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: bos + prefix + sep + query resp + eos
|
||||
Turn t: sep + bos + query resp + eos
|
||||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
||||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx != 0:
|
||||
prefix_ids = sep_ids
|
||||
elif prefix:
|
||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
|
||||
if turn_idx == 0:
|
||||
if prefix: # has prefix
|
||||
prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids
|
||||
else:
|
||||
prefix_ids = []
|
||||
prefix_ids = bos_ids
|
||||
else:
|
||||
prefix_ids = sep_ids + bos_ids
|
||||
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))
|
||||
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
|
||||
return encoded_pairs
|
||||
|
||||
def _convert_inputs_to_ids(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
context: List[Union[str, Dict[str, str]]],
|
||||
query: Optional[str] = ""
|
||||
query: Optional[str] = "",
|
||||
idx: Optional[str] = ""
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Converts context to token ids.
|
||||
|
@ -127,6 +131,7 @@ class Template:
|
|||
for elem in context:
|
||||
if isinstance(elem, str):
|
||||
elem = elem.replace("{{query}}", query, 1)
|
||||
elem = elem.replace("{{idx}}", idx, 1)
|
||||
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
|
@ -146,10 +151,12 @@ class Llama2Template(Template):
|
|||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: bos + prefix + query resp + eos
|
||||
Turn t: bos + query resp + eos
|
||||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
encoded_pairs = []
|
||||
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
|
||||
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string."
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0: # llama2 template has not sep_ids
|
||||
query = prefix[0] + query
|
||||
|
@ -187,10 +194,11 @@ def get_template_and_fix_tokenizer(
|
|||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
|
||||
if tokenizer.eos_token_id is None: # inplace method
|
||||
if len(template.stop_words):
|
||||
if len(template.stop_words): # inplace method
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
else:
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
|
@ -422,12 +430,13 @@ register_template(
|
|||
name="baichuan",
|
||||
prefix=[],
|
||||
prompt=[
|
||||
{"token": "<reserved_102>"},
|
||||
"{{query}}",
|
||||
{"token": "<reserved_103>"}
|
||||
{"token": "<reserved_102>"}, # user token (a little difference in position)
|
||||
"{{query}}"
|
||||
],
|
||||
sep=[],
|
||||
stop_words=[],
|
||||
stop_words=[
|
||||
"<reserved_103>" # assistant token
|
||||
],
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
@ -440,7 +449,8 @@ register_template(
|
|||
name="starchat",
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n"
|
||||
"\n",
|
||||
{"token": "<|end|>"}
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
|
@ -466,7 +476,8 @@ register_template(
|
|||
name="chatml",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\nYou are a helpful assistant."
|
||||
"system\nYou are a helpful assistant.",
|
||||
{"token": "<|im_end|>"}
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|im_start|>"},
|
||||
|
@ -484,3 +495,23 @@ register_template(
|
|||
],
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm2-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"}
|
||||
],
|
||||
prompt=[
|
||||
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
||||
],
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[],
|
||||
use_history=True
|
||||
)
|
||||
|
|
|
@ -24,7 +24,7 @@ class DatasetAttr:
|
|||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
template: str = field(
|
||||
|
|
|
@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field
|
|||
|
||||
@dataclass
|
||||
class FinetuningArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
|
||||
|
@ -14,7 +14,7 @@ class FinetuningArguments:
|
|||
)
|
||||
num_hidden_layers: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model. \
|
||||
metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
|
@ -25,16 +25,16 @@ class FinetuningArguments:
|
|||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
|
||||
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
)
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
||||
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
|
||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"], \
|
||||
Qwen choices: [\"mlp\", \"attn\"], \
|
||||
InternLM, XVERSE choices: the same as LLaMA."}
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
|
@ -51,11 +51,15 @@ class FinetuningArguments:
|
|||
lora_target: Optional[str] = field(
|
||||
default="q_proj,v_proj",
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
InternLM, XVERSE choices: the same as LLaMA."}
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
)
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
|
@ -72,14 +76,14 @@ class FinetuningArguments:
|
|||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
"""Creates an instance from the content of `json_path`."""
|
||||
r"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
||||
|
|
|
@ -4,10 +4,10 @@ from dataclasses import dataclass, field
|
|||
|
||||
@dataclass
|
||||
class GeneralArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to which stage we are going to perform.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
|
|||
|
||||
@dataclass
|
||||
class GeneratingArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
do_sample: Optional[bool] = field(
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import torch
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from huggingface_hub.hf_api import HfFolder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
|
@ -64,12 +63,11 @@ class ModelArguments:
|
|||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
hf_hub_token : Optional[str] = field(
|
||||
hf_auth_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
)
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
@ -77,5 +75,6 @@ class ModelArguments:
|
|||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
if self.use_auth_token == True and self.hf_hub_token != None:
|
||||
HfFolder.save_token(self.hf_hub_token)
|
||||
if self.use_auth_token == True and self.hf_auth_token is not None:
|
||||
from huggingface_hub.hf_api import HfFolder # lazy load
|
||||
HfFolder.save_token(self.hf_auth_token)
|
||||
|
|
|
@ -39,7 +39,7 @@ def init_adapter(
|
|||
if finetuning_args.finetuning_type == "none" and is_trainable:
|
||||
raise ValueError("You cannot use finetuning_type=none while training.")
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ check_min_version("4.29.1")
|
|||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
|
||||
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
|
@ -52,9 +52,6 @@ def load_model_and_tokenizer(
|
|||
logger.warning("Checkpoint is not found at evaluation, load the original model.")
|
||||
finetuning_args = FinetuningArguments(finetuning_type="none")
|
||||
|
||||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with the LoRA method."
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
|
@ -132,8 +129,6 @@ def load_model_and_tokenizer(
|
|||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
assert is_trainable, "PPO stage cannot be performed at evaluation."
|
||||
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
|
|
|
@ -19,7 +19,7 @@ from llmtuner.hparams import (
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
|
@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
|
|||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
]:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
|
@ -68,7 +95,7 @@ def get_train_args(
|
|||
data_args.init_for_training()
|
||||
|
||||
if general_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
@ -76,6 +103,15 @@ def get_train_args(
|
|||
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
|
||||
|
||||
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
|
||||
raise ValueError("PPO and DPO stage can only be performed at training.")
|
||||
|
||||
if general_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
|
@ -133,12 +169,17 @@ def get_train_args(
|
|||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, general_args
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
|
|
|
@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
|
|||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PeftTrainer(Seq2SeqTrainer):
|
||||
class PeftModelMixin:
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self._remove_log()
|
||||
|
||||
def _remove_log(self):
|
||||
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||
def __init__(self) -> None: # for type checking
|
||||
self.model: PreTrainedModel = None
|
||||
self.tokenizer: "PreTrainedTokenizer" = None
|
||||
self.args: "Seq2SeqTrainingArguments" = None
|
||||
self.finetuning_args: "FinetuningArguments" = None
|
||||
self.state: "TrainerState" = None
|
||||
raise AssertionError("Mixin should not be initialized.")
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
|
@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
|
||||
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
Seq2SeqTrainer.__init__(self, **kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from llmtuner.tuner.dpo.workflow import run_dpo
|
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
||||
padded_labels = []
|
||||
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
||||
if self.tokenizer.padding_side == "left":
|
||||
start, end = feature.size(0) - answer_len, feature.size(0)
|
||||
else:
|
||||
start, end = prompt_len, answer_len
|
||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||
padded_tensor[start:end] = feature[start:end]
|
||||
padded_labels.append(padded_tensor)
|
||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
concatenated_features = []
|
||||
label_positions = []
|
||||
for key in ("chosen_ids", "rejected_ids"):
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||
concatenated_features.append({
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (prompt_len + answer_len)
|
||||
})
|
||||
label_positions.append((prompt_len, answer_len))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
concatenated_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
|
@ -0,0 +1,75 @@
|
|||
import torch
|
||||
from collections import defaultdict
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.tuner.core.trainer import PeftModelMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
**kwargs
|
||||
):
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generating_args = generating_args
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = finetuning_args.dpo_beta
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, **kwargs)
|
||||
if ref_model is not None:
|
||||
if hasattr(self, "accelerator"):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
else:
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
batch: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
|
||||
if not torch.is_grad_enabled():
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
|
||||
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
|
||||
with unwrapped_model.disable_adapter():
|
||||
all_logits: torch.Tensor = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
else:
|
||||
all_logits: torch.Tensor = model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
|
||||
if not torch.is_grad_enabled():
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
|
||||
all_logps = self._get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False
|
||||
)
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
|
@ -0,0 +1,59 @@
|
|||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||
|
||||
from copy import deepcopy
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = DPOPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
ref_model=ref_model,
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
|
@ -10,7 +10,7 @@ from trl import PPOTrainer
|
|||
from trl.core import LengthSampler
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
|
@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
|||
from transformers import Seq2SeqTrainingArguments
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -33,16 +33,17 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
self,
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["LogCallback"],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
self.args = training_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generating_args = generating_args
|
||||
self.log_callback = callbacks[0]
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self._remove_log()
|
||||
|
||||
def ppo_train(self, max_target_length: int) -> None:
|
||||
r"""
|
||||
|
@ -72,14 +73,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.pad_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
gen_kwargs = self.generating_args.to_dict()
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
||||
|
||||
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
# Inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from trl import PPOConfig
|
||||
from torch.optim import AdamW
|
||||
from typing import Optional, List
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
|
@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
|
@ -24,6 +22,7 @@ def run_ppo(
|
|||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
|
@ -42,8 +41,9 @@ def run_ppo(
|
|||
)
|
||||
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
total_train_batch_size = \
|
||||
total_train_batch_size = (
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
lr_scheduler = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
|
@ -56,6 +56,7 @@ def run_ppo(
|
|||
ppo_trainer = PPOPeftTrainer(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
|
@ -67,8 +68,10 @@ def run_ppo(
|
|||
lr_scheduler=lr_scheduler
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be after save_model
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
|
|
@ -2,10 +2,9 @@
|
|||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
@ -25,10 +24,7 @@ def run_pt(
|
|||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
|
@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
|||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [
|
||||
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
|
||||
for key in ("accept_ids", "reject_ids") for feature in features
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
|
||||
}
|
||||
for key in ("chosen_ids", "rejected_ids") for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
|
|
|
@ -79,7 +79,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
|||
|
||||
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
return padded_tensor.contiguous()
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
|
|
|
@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
|
|||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
|
@ -13,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_sft(
|
||||
|
@ -21,6 +21,7 @@ def run_sft(
|
|||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
|
@ -50,13 +51,9 @@ def run_sft(
|
|||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
"do_sample": True,
|
||||
"top_p": 0.7,
|
||||
"max_new_tokens": data_args.max_target_length + 1,
|
||||
"temperature": 0.95,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
gen_kwargs = generating_args.to_dict()
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
|
|
@ -1,35 +1,47 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
||||
from llmtuner.tuner.dpo import run_dpo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks + [LogCallback()]
|
||||
|
||||
if general_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif general_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif general_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif general_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif general_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
model_args, _, training_args, finetuning_args, _ = get_train_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
|
||||
try:
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -4,7 +4,7 @@ import threading
|
|||
import time
|
||||
import transformers
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from typing import Generator, List, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE
|
||||
|
|
Loading…
Reference in New Issue