improve KTO impl., replace datasets

This commit is contained in:
hiyouga 2024-05-18 03:44:56 +08:00
parent 33a354548e
commit c450ee87a3
65 changed files with 46415 additions and 2035053 deletions

View File

@ -45,7 +45,7 @@ Choose your path:
## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
@ -69,14 +69,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
<details><summary>Full Changelog</summary>
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
@ -188,6 +190,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## Provided Datasets
@ -208,12 +211,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
<details><summary>Supervised fine-tuning datasets</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Identity (en&zh)](data/identity.json)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@ -222,7 +225,6 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
@ -235,15 +237,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@ -259,13 +262,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
<details><summary>Preference datasets</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
</details>

View File

@ -45,7 +45,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
@ -69,14 +69,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
<details><summary>展开日志</summary>
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
@ -188,6 +190,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## 数据集
@ -208,12 +211,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
<details><summary>指令微调数据集</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Identity (en&zh)](data/identity.json)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@ -222,7 +225,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
@ -235,15 +237,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@ -259,13 +262,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
<details><summary>偏好数据集</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
</details>

View File

@ -19,7 +19,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
"images": "the column name in the dataset containing the image inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
@ -42,13 +45,13 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
```json
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
["human instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
@ -69,7 +72,7 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
}
```
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
The `query` column will be concatenated with the `prompt` column and used as the human prompt, then the human prompt would be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
@ -98,12 +101,10 @@ For the **preference datasets**, the `response` column should be a string list w
```json
[
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
"instruction": "human instruction",
"input": "human input",
"chosen": "chosen answer",
"rejected": "rejected answer"
}
]
```
@ -117,7 +118,8 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@ -132,7 +134,7 @@ The dataset in **sharegpt** format should follow the below format:
"conversations": [
{
"from": "human",
"value": "user instruction"
"value": "human instruction"
},
{
"from": "gpt",
@ -179,7 +181,7 @@ We also supports the dataset in the **openai** format:
},
{
"role": "user",
"content": "user instruction"
"content": "human instruction"
},
{
"role": "assistant",

View File

@ -19,7 +19,10 @@
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
"images": "数据集代表图像输入的表头名称默认None",
"chosen": "数据集代表更优回复的表头名称默认None",
"rejected": "数据集代表更差回复的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
@ -42,8 +45,8 @@
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
@ -69,7 +72,7 @@
}
```
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为人类指令,即人类指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
@ -98,12 +101,10 @@
```json
[
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"优质回答",
"劣质回答"
]
"instruction": "人类指令",
"input": "人类输入",
"chosen": "优质回答",
"rejected": "劣质回答"
}
]
```
@ -117,7 +118,8 @@
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@ -132,7 +134,7 @@
"conversations": [
{
"from": "human",
"value": "用户指令"
"value": "人类指令"
},
{
"from": "gpt",
@ -165,7 +167,7 @@
}
```
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
其中 `messages` 列应当是一个列表,且符合 `人类/模型/人类/模型/人类/模型` 的顺序。
我们同样支持 **openai** 格式的数据集:
@ -179,7 +181,7 @@
},
{
"role": "user",
"content": "用户指令"
"content": "人类指令"
},
{
"role": "assistant",

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

5002
data/alpaca_en_demo.json Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

5002
data/alpaca_zh_demo.json Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@ -1,48 +1,23 @@
{
"alpaca_en": {
"file_name": "alpaca_data_en_52k.json"
},
"alpaca_zh": {
"file_name": "alpaca_data_zh_51k.json"
},
"alpaca_gpt4_en": {
"file_name": "alpaca_gpt4_data_en.json"
},
"alpaca_gpt4_zh": {
"file_name": "alpaca_gpt4_data_zh.json"
},
"identity": {
"file_name": "identity.json"
},
"oaast_sft_zh": {
"file_name": "oaast_sft_zh.json",
"alpaca_en_demo": {
"file_name": "alpaca_en_demo.json"
},
"alpaca_zh_demo": {
"file_name": "alpaca_zh_demo.json"
},
"glaive_toolcall_en_demo": {
"file_name": "glaive_toolcall_en_demo.json",
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"tools": "tools"
}
},
"lima": {
"file_name": "lima.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"kto-mix-test": {
"file_name": "kto-mix-test.json",
"file_sha1": "91b59f657007dc4b17529fc643v9b9cd6d640fha",
"columns": {
"prompt": "instruction",
"response": "output",
"tag": "tag"
}
},
"glaive_toolcall": {
"file_name": "glaive_toolcall_10k.json",
"glaive_toolcall_zh_demo": {
"file_name": "glaive_toolcall_zh_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
@ -63,15 +38,42 @@
"assistant_tag": "assistant"
}
},
"example": {
"script_url": "example_dataset",
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
},
"alpaca_zh": {
"hf_hub_url": "llamafactory/alpaca_zh",
"ms_hub_url": "llamafactory/alpaca_zh"
},
"alpaca_gpt4_en": {
"hf_hub_url": "llamafactory/alpaca_gpt4_en",
"ms_hub_url": "llamafactory/alpaca_gpt4_en"
},
"alpaca_gpt4_zh": {
"hf_hub_url": "llamafactory/alpaca_gpt4_zh",
"ms_hub_url": "llamafactory/alpaca_gpt4_zh"
},
"glaive_toolcall_en": {
"hf_hub_url": "llamafactory/glaive_toolcall_en",
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"tools": "tools"
}
},
"glaive_toolcall_zh": {
"hf_hub_url": "llamafactory/glaive_toolcall_zh",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"tools": "tools"
}
},
"lima": {
"hf_hub_url": "llamafactory/lima",
"formatting": "sharegpt"
},
"guanaco": {
"hf_hub_url": "JosephusCheung/GuanacoDataset",
"ms_hub_url": "AI-ModelScope/GuanacoDataset"
@ -240,6 +242,12 @@
"response": "text"
}
},
"stem_zh": {
"hf_hub_url": "hfl/stem_zh_instruction"
},
"ruozhiba_gpt4": {
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
},
"llava_150k_en": {
"hf_hub_url": "BUAADreamer/llava-en-zh-300k",
"subset": "en",
@ -297,73 +305,105 @@
"ultrachat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"dpo_en_demo": {
"file_name": "dpo_en_demo.json",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"response": "output",
"history": "history"
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"ranking": true
},
"oaast_rm_zh": {
"file_name": "oaast_rm_zh.json",
"dpo_zh_demo": {
"file_name": "dpo_zh_demo.json",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"ranking": true
"dpo_mix_en": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "en",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json",
"ranking": true
"dpo_mix_zh": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "zh",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json",
"ranking": true
},
"orca_rlhf": {
"file_name": "orca_rlhf.json",
"orca_pairs": {
"hf_hub_url": "Intel/orca_dpo_pairs",
"ranking": true,
"columns": {
"prompt": "question",
"response": "answer",
"chosen": "chosen",
"rejected": "rejected",
"system": "system"
}
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"ranking": true,
"columns": {
"prompt": "instruction",
"chosen": "chosen",
"rejected": "rejected",
"history": "history"
}
},
"nectar_rm": {
"hf_hub_url": "AstraMindAI/RLAIF-Nectar",
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
"ranking": true
},
"dpo_mix_en": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "en",
"ranking": true,
"columns": {
"prompt": "prompt",
"response": "answer",
"system": "system",
"history": "history"
}
},
"dpo_mix_zh": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "zh",
"ranking": true,
"columns": {
"prompt": "prompt",
"response": "answer",
"system": "system",
"history": "history"
}
},
"orca_dpo_de": {
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
"ranking": true
},
"kto_en_demo": {
"file_name": "kto_en_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"kto_tag": "label"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"kto_mix_en": {
"hf_hub_url": "argilla/kto-mix-15k",
"formatting": "sharegpt",
"columns": {
"messages": "completion",
"kto_tag": "label"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"wiki_demo": {
"file_name": "wiki_demo.txt",
"columns": {

7226
data/dpo_en_demo.json Normal file

File diff suppressed because one or more lines are too long

5058
data/dpo_zh_demo.json Normal file

File diff suppressed because one or more lines are too long

View File

@ -1,37 +0,0 @@
import json
from typing import Any, Dict, Generator, List, Tuple
import datasets
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
for key, example in enumerate(example_dataset):
yield key, example

View File

@ -1,20 +0,0 @@
[
{
"instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
"input": "",
"output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
"history": [
["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
]
},
{
"instruction": "好的,谢谢你!",
"input": "",
"output": "不客气,有其他需要帮忙的地方可以继续问我。",
"history": [
["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
["我在纽约。", "纽约今天晴间多云气温最高约26摄氏度最低约18摄氏度记得注意保暖喔。"]
]
}
]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@ -79,5 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
break
prompt = prompt[:human_idx]
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
key += 1

File diff suppressed because one or more lines are too long

5398
data/kto_en_demo.json Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
```
#### KTO Training
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
```
#### ORPO Training
```bash

View File

@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
```
#### KTO 训练
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
```
#### ORPO 训练
```bash

View File

@ -11,7 +11,7 @@ badam_switch_interval: 50
badam_verbose: 2
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ lora_target: q_proj,v_proj
ddp_timeout: 180000000
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ galore_rank: 128
galore_scale: 2.0
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -10,7 +10,7 @@ freeze_trainable_modules: all
use_llama_pro: true
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
loraplus_lr_ratio: 16.0
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: full
mixture_of_depths: convert
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -7,7 +7,7 @@ do_predict: true
finetuning_type: full
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 50

View File

@ -11,7 +11,7 @@ ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -11,7 +11,7 @@ lora_target: q_proj,v_proj
ddp_timeout: 180000000
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
dpo_ftx: 1.0
### dataset
dataset: orca_rlhf
dataset: dpo_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
@ -26,7 +26,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.00001
learning_rate: 0.000005
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1

View File

@ -0,0 +1,39 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
### method
stage: kto
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj
kto_ftx: 0.1
### dataset
dataset: kto_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b/lora/kto
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.000005
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1
fp16: true
### eval
val_size: 0.1
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 500

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: orca_rlhf
dataset: dpo_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
@ -25,7 +25,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.00001
learning_rate: 0.000005
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1

View File

@ -9,7 +9,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ do_predict: true
finetuning_type: lora
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 50

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: orca_rlhf
dataset: dpo_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -9,7 +9,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -1,12 +1,12 @@
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = [
"PairwiseDataCollatorWithPadding",
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from ..extras.logging import get_logger
from .utils import Role
@ -14,7 +15,13 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
@ -29,7 +36,10 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
@ -45,23 +55,33 @@ def convert_alpaca(
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
return outputs
@ -69,6 +89,9 @@ def convert_alpaca(
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
@ -88,21 +111,62 @@ def convert_sharegpt(
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
@ -138,7 +202,6 @@ def align_dataset(
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
"tag": {"dtype": "bool", "_type": "Value"},
}
)
kwargs = {}

View File

@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features, return_tensors=None):
concatenated_features = []
kl_concatenated_features = []
tags = []
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
concatenated_features.append(
target_features.append(
{
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
)
kl_concatenated_features.append(
kl_features.append(
{
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
)
tags.append(feature["tag"])
batch = super().__call__(concatenated_features)
kl_batch = super().__call__(kl_concatenated_features)
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
batch["tag"] = torch.tensor(tags)
batch["kto_tags"] = torch.tensor(kto_tags)
return batch

View File

@ -57,7 +57,7 @@ def load_single_dataset(
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
raise ValueError("File {} not found.".format(local_path))
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
@ -116,7 +116,7 @@ def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
stage: Literal["pt", "sft", "rm", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:

View File

@ -25,21 +25,22 @@ class DatasetAttr:
folder: Optional[str] = None
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
""" common columns """
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
tag: Optional[bool] = None
""" columns for the alpaca format """
""" rlhf columns """
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
chosen: Optional[str] = "chosen"
rejected: Optional[str] = "rejected"
history: Optional[str] = None
""" columns for the sharegpt format """
""" sharegpt columns """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
""" sharegpt tags """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
@ -107,11 +108,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system", "images", "tag"]
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

View File

@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["tag"].append(examples["tag"])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
examples['kl_response'] = examples['response'][::-1]
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
input_ids, labels = [], []
kl_input_ids, kl_labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
kl_input_ids += source_ids + target_ids
kl_labels += source_mask + target_ids
if template.efficient_eos:
kl_input_ids += [tokenizer.eos_token_id]
kl_labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["tag"].append(examples["tag"][i])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
if desirable == 0 or undesirable == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
messages = examples["prompt"][i] + [examples["response"][i][0]]
else: # undesired example
kto_tag = False
messages = examples["prompt"][i] + [examples["response"][i][1]]
if kl_response[i][0]["content"]:
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
else:
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, kl_response_ids = template.encode_oneturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
stage: Literal["pt", "sft", "rm", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],

View File

@ -137,21 +137,21 @@ class RLHFArguments:
default=0.1,
metadata={"help": "The beta parameter for the KTO loss."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
kto_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
)
kto_desirable_weight: float = field(
default=1.0,
metadata={"help": "The desirable weight for the KTO loss."},
)
kto_undesirable_weight: float = field(
default=1.0,
metadata={"help": "The undesirable weight for the KTO loss."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
@ -307,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)

View File

@ -47,11 +47,13 @@ class CustomDPOTrainer(DPOTrainer):
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# dpo hyperparams
self.beta = finetuning_args.dpo_beta
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@ -143,6 +145,7 @@ class CustomDPOTrainer(DPOTrainer):
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model

View File

@ -1,7 +1,7 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import torch
from transformers import Trainer
@ -13,7 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel
from transformers import PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments
@ -24,6 +24,7 @@ class CustomKTOTrainer(KTOTrainer):
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
disable_dropout: bool = True,
**kwargs,
):
@ -33,6 +34,7 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@ -43,15 +45,15 @@ class CustomKTOTrainer(KTOTrainer):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# KTO parameter
# kto hyperparams
self.beta = finetuning_args.kto_beta
self.desirable_weight = finetuning_args.kto_chosen_weight
self.undesirable_weight = finetuning_args.kto_rejected_weight
self.ftx_gamma = finetuning_args.kto_ftx
self.desirable_weight = finetuning_args.kto_desirable_weight
self.undesirable_weight = finetuning_args.kto_undesirable_weight
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@ -82,78 +84,85 @@ class CustomKTOTrainer(KTOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps.nanmean()
return -all_logps
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
with torch.no_grad():
KL_logits = model(
batch["KL_completion_input_ids"],
attention_mask=batch["KL_completion_attention_mask"],
).logits
kl_logits = model(
input_ids=batch["kl_input_ids"],
attention_mask=batch["kl_attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
completion_logits = model(
batch["input_ids"],
target_logits = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
).logits
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
completion_logps = self.get_batch_logps(
completion_logits,
batch["labels"],
target_logps = self.get_batch_logps(
logits=target_logits,
labels=batch["labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
KL_logps = self.get_batch_logps(
KL_logits,
batch["kl_labels"],
kl_logps = self.get_batch_logps(
logits=kl_logits,
labels=batch["kl_labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
if completion_logps.shape[0] != len(batch["tag"]):
raise ValueError(
"There is a mismatch between the number of examples in this batch and the number of "
"examples for which an output sequence was predicted."
)
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]
if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")
chosen_logps = completion_logps[chosen_idx, ...]
rejected_logps = completion_logps[rejected_idx, ...]
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
chosen_logits = completion_logits[chosen_idx, ...]
rejected_logits = completion_logits[rejected_idx, ...]
chosen_logps = target_logps[chosen_idx, ...]
rejected_logps = target_logps[rejected_idx, ...]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
chosen_logits = target_logits[chosen_idx, ...]
rejected_logits = target_logits[rejected_idx, ...]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
):
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_KL_logps,
_,
policy_kl_logps,
) = self.forward(model, batch)
with torch.no_grad():
@ -163,27 +172,29 @@ class CustomKTOTrainer(KTOTrainer):
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
reference_kl_logps,
) = self.forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
policy_kl_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_KL_logps,
reference_kl_logps,
)
losses = losses.nanmean()
if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])
if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

View File

@ -48,9 +48,9 @@ def run_kto(
ref_model=ref_model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
)

View File

@ -29,7 +29,7 @@ def run_ppo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training

View File

@ -9,12 +9,13 @@ from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .dpo import run_dpo
from .kto import run_kto
from .orpo import run_orpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .kto import run_kto
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -37,10 +38,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")