improve KTO impl., replace datasets
This commit is contained in:
parent
33a354548e
commit
c450ee87a3
34
README.md
34
README.md
|
@ -45,7 +45,7 @@ Choose your path:
|
|||
## Features
|
||||
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO.
|
||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
|
||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||
|
@ -69,14 +69,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Changelog
|
||||
|
||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
||||
|
||||
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
||||
|
||||
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
||||
|
@ -188,6 +190,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
## Provided Datasets
|
||||
|
@ -208,12 +211,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
|
||||
<details><summary>Supervised fine-tuning datasets</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Identity (en&zh)](data/identity.json)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
|
@ -222,7 +225,6 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
|
@ -235,15 +237,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
|
@ -259,13 +262,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||
|
||||
<details><summary>Preference datasets</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
34
README_zh.md
34
README_zh.md
|
@ -45,7 +45,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。
|
||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
|
||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||
|
@ -69,14 +69,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
## 更新日志
|
||||
|
||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
|
||||
|
||||
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
|
||||
|
||||
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
@ -188,6 +190,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
## 数据集
|
||||
|
@ -208,12 +211,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
<details><summary>指令微调数据集</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Identity (en&zh)](data/identity.json)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
|
@ -222,7 +225,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
|
@ -235,15 +237,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
|
@ -259,13 +262,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
<details><summary>偏好数据集</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -19,7 +19,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
|
|||
"messages": "the column name in the dataset containing the messages. (default: conversations)",
|
||||
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
||||
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)"
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
||||
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
||||
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
||||
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
||||
},
|
||||
"tags (optional, used for the sharegpt format)": {
|
||||
"role_tag": "the key in the message represents the identity. (default: from)",
|
||||
|
@ -42,13 +45,13 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
|
|||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"instruction": "human instruction (required)",
|
||||
"input": "human input (optional)",
|
||||
"output": "model response (required)",
|
||||
"system": "system prompt (optional)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
["human instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["human instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
|
@ -69,7 +72,7 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
|||
}
|
||||
```
|
||||
|
||||
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
The `query` column will be concatenated with the `prompt` column and used as the human prompt, then the human prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
|
||||
|
||||
|
@ -98,12 +101,10 @@ For the **preference datasets**, the `response` column should be a string list w
|
|||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
"instruction": "human instruction",
|
||||
"input": "human input",
|
||||
"chosen": "chosen answer",
|
||||
"rejected": "rejected answer"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
@ -117,7 +118,8 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
|||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
@ -132,7 +134,7 @@ The dataset in **sharegpt** format should follow the below format:
|
|||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "user instruction"
|
||||
"value": "human instruction"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
|
@ -179,7 +181,7 @@ We also supports the dataset in the **openai** format:
|
|||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "user instruction"
|
||||
"content": "human instruction"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
|
|
@ -19,7 +19,10 @@
|
|||
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)"
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)",
|
||||
"chosen": "数据集代表更优回复的表头名称(默认:None)",
|
||||
"rejected": "数据集代表更差回复的表头名称(默认:None)",
|
||||
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
||||
},
|
||||
"tags(可选,用于 sharegpt 格式)": {
|
||||
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||
|
@ -42,8 +45,8 @@
|
|||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"instruction": "人类指令(必填)",
|
||||
"input": "人类输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"system": "系统提示词(选填)",
|
||||
"history": [
|
||||
|
@ -69,7 +72,7 @@
|
|||
}
|
||||
```
|
||||
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为人类指令,即人类指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
|
||||
|
||||
|
@ -98,12 +101,10 @@
|
|||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
"instruction": "人类指令",
|
||||
"input": "人类输入",
|
||||
"chosen": "优质回答",
|
||||
"rejected": "劣质回答"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
@ -117,7 +118,8 @@
|
|||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
@ -132,7 +134,7 @@
|
|||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "用户指令"
|
||||
"value": "人类指令"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
|
@ -165,7 +167,7 @@
|
|||
}
|
||||
```
|
||||
|
||||
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
其中 `messages` 列应当是一个列表,且符合 `人类/模型/人类/模型/人类/模型` 的顺序。
|
||||
|
||||
我们同样支持 **openai** 格式的数据集:
|
||||
|
||||
|
@ -179,7 +181,7 @@
|
|||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "用户指令"
|
||||
"content": "人类指令"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
|
260012
data/alpaca_data_en_52k.json
260012
data/alpaca_data_en_52k.json
File diff suppressed because it is too large
Load Diff
257306
data/alpaca_data_zh_51k.json
257306
data/alpaca_data_zh_51k.json
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
|
@ -1,48 +1,23 @@
|
|||
{
|
||||
"alpaca_en": {
|
||||
"file_name": "alpaca_data_en_52k.json"
|
||||
},
|
||||
"alpaca_zh": {
|
||||
"file_name": "alpaca_data_zh_51k.json"
|
||||
},
|
||||
"alpaca_gpt4_en": {
|
||||
"file_name": "alpaca_gpt4_data_en.json"
|
||||
},
|
||||
"alpaca_gpt4_zh": {
|
||||
"file_name": "alpaca_gpt4_data_zh.json"
|
||||
},
|
||||
"identity": {
|
||||
"file_name": "identity.json"
|
||||
},
|
||||
"oaast_sft_zh": {
|
||||
"file_name": "oaast_sft_zh.json",
|
||||
"alpaca_en_demo": {
|
||||
"file_name": "alpaca_en_demo.json"
|
||||
},
|
||||
"alpaca_zh_demo": {
|
||||
"file_name": "alpaca_zh_demo.json"
|
||||
},
|
||||
"glaive_toolcall_en_demo": {
|
||||
"file_name": "glaive_toolcall_en_demo.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
"messages": "conversations",
|
||||
"tools": "tools"
|
||||
}
|
||||
},
|
||||
"lima": {
|
||||
"file_name": "lima.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
},
|
||||
"kto-mix-test": {
|
||||
"file_name": "kto-mix-test.json",
|
||||
"file_sha1": "91b59f657007dc4b17529fc643v9b9cd6d640fha",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"response": "output",
|
||||
"tag": "tag"
|
||||
}
|
||||
},
|
||||
"glaive_toolcall": {
|
||||
"file_name": "glaive_toolcall_10k.json",
|
||||
"glaive_toolcall_zh_demo": {
|
||||
"file_name": "glaive_toolcall_zh_demo.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
|
@ -63,15 +38,42 @@
|
|||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"example": {
|
||||
"script_url": "example_dataset",
|
||||
"alpaca_en": {
|
||||
"hf_hub_url": "llamafactory/alpaca_en",
|
||||
"ms_hub_url": "llamafactory/alpaca_en"
|
||||
},
|
||||
"alpaca_zh": {
|
||||
"hf_hub_url": "llamafactory/alpaca_zh",
|
||||
"ms_hub_url": "llamafactory/alpaca_zh"
|
||||
},
|
||||
"alpaca_gpt4_en": {
|
||||
"hf_hub_url": "llamafactory/alpaca_gpt4_en",
|
||||
"ms_hub_url": "llamafactory/alpaca_gpt4_en"
|
||||
},
|
||||
"alpaca_gpt4_zh": {
|
||||
"hf_hub_url": "llamafactory/alpaca_gpt4_zh",
|
||||
"ms_hub_url": "llamafactory/alpaca_gpt4_zh"
|
||||
},
|
||||
"glaive_toolcall_en": {
|
||||
"hf_hub_url": "llamafactory/glaive_toolcall_en",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
"messages": "conversations",
|
||||
"tools": "tools"
|
||||
}
|
||||
},
|
||||
"glaive_toolcall_zh": {
|
||||
"hf_hub_url": "llamafactory/glaive_toolcall_zh",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"tools": "tools"
|
||||
}
|
||||
},
|
||||
"lima": {
|
||||
"hf_hub_url": "llamafactory/lima",
|
||||
"formatting": "sharegpt"
|
||||
},
|
||||
"guanaco": {
|
||||
"hf_hub_url": "JosephusCheung/GuanacoDataset",
|
||||
"ms_hub_url": "AI-ModelScope/GuanacoDataset"
|
||||
|
@ -240,6 +242,12 @@
|
|||
"response": "text"
|
||||
}
|
||||
},
|
||||
"stem_zh": {
|
||||
"hf_hub_url": "hfl/stem_zh_instruction"
|
||||
},
|
||||
"ruozhiba_gpt4": {
|
||||
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
|
||||
},
|
||||
"llava_150k_en": {
|
||||
"hf_hub_url": "BUAADreamer/llava-en-zh-300k",
|
||||
"subset": "en",
|
||||
|
@ -297,73 +305,105 @@
|
|||
"ultrachat_de": {
|
||||
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
|
||||
},
|
||||
"hh_rlhf_en": {
|
||||
"script_url": "hh_rlhf_en",
|
||||
"dpo_en_demo": {
|
||||
"file_name": "dpo_en_demo.json",
|
||||
"ranking": true,
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
"messages": "conversations",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
},
|
||||
"ranking": true
|
||||
},
|
||||
"oaast_rm_zh": {
|
||||
"file_name": "oaast_rm_zh.json",
|
||||
"dpo_zh_demo": {
|
||||
"file_name": "dpo_zh_demo.json",
|
||||
"ranking": true,
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
"messages": "conversations",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
},
|
||||
"ranking": true
|
||||
"dpo_mix_en": {
|
||||
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
|
||||
"subset": "en",
|
||||
"ranking": true,
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
},
|
||||
"comparison_gpt4_en": {
|
||||
"file_name": "comparison_gpt4_data_en.json",
|
||||
"ranking": true
|
||||
"dpo_mix_zh": {
|
||||
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
|
||||
"subset": "zh",
|
||||
"ranking": true,
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
},
|
||||
"comparison_gpt4_zh": {
|
||||
"file_name": "comparison_gpt4_data_zh.json",
|
||||
"ranking": true
|
||||
},
|
||||
"orca_rlhf": {
|
||||
"file_name": "orca_rlhf.json",
|
||||
"orca_pairs": {
|
||||
"hf_hub_url": "Intel/orca_dpo_pairs",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "question",
|
||||
"response": "answer",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected",
|
||||
"system": "system"
|
||||
}
|
||||
},
|
||||
"hh_rlhf_en": {
|
||||
"script_url": "hh_rlhf_en",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected",
|
||||
"history": "history"
|
||||
}
|
||||
},
|
||||
"nectar_rm": {
|
||||
"hf_hub_url": "AstraMindAI/RLAIF-Nectar",
|
||||
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
|
||||
"ranking": true
|
||||
},
|
||||
"dpo_mix_en": {
|
||||
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
|
||||
"subset": "en",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "prompt",
|
||||
"response": "answer",
|
||||
"system": "system",
|
||||
"history": "history"
|
||||
}
|
||||
},
|
||||
"dpo_mix_zh": {
|
||||
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
|
||||
"subset": "zh",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "prompt",
|
||||
"response": "answer",
|
||||
"system": "system",
|
||||
"history": "history"
|
||||
}
|
||||
},
|
||||
"orca_dpo_de": {
|
||||
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
|
||||
"ranking": true
|
||||
},
|
||||
"kto_en_demo": {
|
||||
"file_name": "kto_en_demo.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages",
|
||||
"kto_tag": "label"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"kto_mix_en": {
|
||||
"hf_hub_url": "argilla/kto-mix-15k",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "completion",
|
||||
"kto_tag": "label"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"wiki_demo": {
|
||||
"file_name": "wiki_demo.txt",
|
||||
"columns": {
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,37 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_DESCRIPTION = "An example of dataset."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = ""
|
||||
_LICENSE = ""
|
||||
_URL = "examples.json"
|
||||
|
||||
|
||||
class ExampleDataset(datasets.GeneratorBasedBuilder):
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"input": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
|
||||
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
|
||||
for key, example in enumerate(example_dataset):
|
||||
yield key, example
|
|
@ -1,20 +0,0 @@
|
|||
[
|
||||
{
|
||||
"instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
|
||||
"input": "",
|
||||
"output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
|
||||
"history": [
|
||||
["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
|
||||
["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
|
||||
]
|
||||
},
|
||||
{
|
||||
"instruction": "好的,谢谢你!",
|
||||
"input": "",
|
||||
"output": "不客气,有其他需要帮忙的地方可以继续问我。",
|
||||
"history": [
|
||||
["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
|
||||
["我在纽约。", "纽约今天晴间多云,气温最高约26摄氏度,最低约18摄氏度,记得注意保暖喔。"]
|
||||
]
|
||||
}
|
||||
]
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
@ -79,5 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
|||
break
|
||||
prompt = prompt[:human_idx]
|
||||
|
||||
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
|
||||
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
|
||||
key += 1
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
6417
data/lima.json
6417
data/lima.json
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
102874
data/orca_rlhf.json
102874
data/orca_rlhf.json
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
|||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### KTO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### ORPO Training
|
||||
|
||||
```bash
|
||||
|
|
|
@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
|||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### KTO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### ORPO 训练
|
||||
|
||||
```bash
|
||||
|
|
|
@ -11,7 +11,7 @@ badam_switch_interval: 50
|
|||
badam_verbose: 2
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -12,7 +12,7 @@ lora_target: q_proj,v_proj
|
|||
ddp_timeout: 180000000
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -12,7 +12,7 @@ galore_rank: 128
|
|||
galore_scale: 2.0
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -10,7 +10,7 @@ freeze_trainable_modules: all
|
|||
use_llama_pro: true
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
|
|||
loraplus_lr_ratio: 16.0
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: full
|
|||
mixture_of_depths: convert
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -7,7 +7,7 @@ do_predict: true
|
|||
finetuning_type: full
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
|
|
|
@ -11,7 +11,7 @@ ddp_timeout: 180000000
|
|||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -11,7 +11,7 @@ lora_target: q_proj,v_proj
|
|||
ddp_timeout: 180000000
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -12,7 +12,7 @@ ddp_timeout: 180000000
|
|||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -12,7 +12,7 @@ ddp_timeout: 180000000
|
|||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
|
|||
dpo_ftx: 1.0
|
||||
|
||||
### dataset
|
||||
dataset: orca_rlhf
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -26,7 +26,7 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 0.000005
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
### method
|
||||
stage: kto
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
kto_ftx: 0.1
|
||||
|
||||
### dataset
|
||||
dataset: kto_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/kto
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.000005
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
fp16: true
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 500
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: orca_rlhf
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -25,7 +25,7 @@ overwrite_output_dir: true
|
|||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 0.000005
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
|
|
|
@ -9,7 +9,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ do_predict: true
|
|||
finetuning_type: lora
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: orca_rlhf
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -9,7 +9,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
|||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .loader import get_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
|||
|
||||
from datasets import Features
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import Role
|
||||
|
||||
|
||||
|
@ -14,7 +15,13 @@ if TYPE_CHECKING:
|
|||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
"""
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
|
@ -29,7 +36,10 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
|
|||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
|
||||
r"""
|
||||
Converts alpaca format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
|
@ -45,23 +55,33 @@ def convert_alpaca(
|
|||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
||||
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag], bool): # kto example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
if examples[dataset_attr.kto_tag]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], str)
|
||||
and isinstance(examples[dataset_attr.rejected][i], str)
|
||||
): # pairwise example
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else: # unsupervised
|
||||
response = []
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -69,6 +89,9 @@ def convert_alpaca(
|
|||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts sharegpt format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
|
@ -88,21 +111,62 @@ def convert_sharegpt(
|
|||
else:
|
||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
aligned_messages = []
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], dict)
|
||||
and isinstance(examples[dataset_attr.rejected][i], dict)
|
||||
): # pairwise example
|
||||
chosen = examples[dataset_attr.chosen][i]
|
||||
rejected = examples[dataset_attr.rejected][i]
|
||||
if (
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
response = [
|
||||
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
||||
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
||||
]
|
||||
else: # normal example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
continue
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
@ -138,7 +202,6 @@ def align_dataset(
|
|||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
"tag": {"dtype": "bool", "_type": "Value"},
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
|
|
|
@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
def __call__(self, features, return_tensors=None):
|
||||
concatenated_features = []
|
||||
kl_concatenated_features = []
|
||||
tags = []
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
concatenated_features.append(
|
||||
target_features.append(
|
||||
{
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
)
|
||||
kl_concatenated_features.append(
|
||||
kl_features.append(
|
||||
{
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
)
|
||||
tags.append(feature["tag"])
|
||||
batch = super().__call__(concatenated_features)
|
||||
kl_batch = super().__call__(kl_concatenated_features)
|
||||
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
|
||||
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
kl_batch = super().__call__(kl_features)
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
batch["tag"] = torch.tensor(tags)
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
|
@ -57,7 +57,7 @@ def load_single_dataset(
|
|||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
|
@ -116,7 +116,7 @@ def get_dataset(
|
|||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
|
|
|
@ -25,21 +25,22 @@ class DatasetAttr:
|
|||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
tag: Optional[bool] = None
|
||||
""" columns for the alpaca format """
|
||||
""" rlhf columns """
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
""" alpaca columns """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
chosen: Optional[str] = "chosen"
|
||||
rejected: Optional[str] = "rejected"
|
||||
history: Optional[str] = None
|
||||
""" columns for the sharegpt format """
|
||||
""" sharegpt columns """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
""" tags for the sharegpt format """
|
||||
""" sharegpt tags """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
|
@ -107,11 +108,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "images", "tag"]
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names.extend(["messages", "tools"])
|
||||
column_names.extend(["messages"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
|
|
@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
|
|||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
|
|||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["tag"].append(examples["tag"])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
|
||||
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
|
||||
examples['kl_response'] = examples['response'][::-1]
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
|
||||
input_ids, labels = [], []
|
||||
kl_input_ids, kl_labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
kl_input_ids += source_ids + target_ids
|
||||
kl_labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
kl_input_ids += [tokenizer.eos_token_id]
|
||||
kl_labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["tag"].append(examples["tag"][i])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
|
||||
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
|
||||
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
|
||||
if desirable == 0 or undesirable == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
return model_inputs
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
|
@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
|
|||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
|||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
|
|
@ -137,21 +137,21 @@ class RLHFArguments:
|
|||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the KTO loss."},
|
||||
)
|
||||
kto_chosen_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The weight factor of the desirable losses in KTO training."},
|
||||
)
|
||||
kto_rejected_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
|
||||
)
|
||||
kto_ftx: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
|
||||
)
|
||||
kto_desirable_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The desirable weight for the KTO loss."},
|
||||
)
|
||||
kto_undesirable_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The undesirable weight for the KTO loss."},
|
||||
)
|
||||
orpo_beta: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
|
||||
metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
|
||||
)
|
||||
ppo_buffer_size: int = field(
|
||||
default=1,
|
||||
|
@ -307,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||
default=False,
|
||||
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||
)
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."},
|
||||
)
|
||||
|
|
|
@ -47,11 +47,13 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
self.ref_model = ref_model
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# dpo hyperparams
|
||||
self.beta = finetuning_args.dpo_beta
|
||||
self.label_smoothing = finetuning_args.dpo_label_smoothing
|
||||
self.loss_type = finetuning_args.dpo_loss
|
||||
self.ftx_gamma = finetuning_args.dpo_ftx
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
|
@ -143,6 +145,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
ref_model = self.model
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
|
@ -13,7 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import PreTrainedModel, ProcessorMixin
|
||||
|
||||
from ...hparams import FinetuningArguments
|
||||
|
||||
|
@ -24,6 +24,7 @@ class CustomKTOTrainer(KTOTrainer):
|
|||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
|
||||
finetuning_args: "FinetuningArguments",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
disable_dropout: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -33,6 +34,7 @@ class CustomKTOTrainer(KTOTrainer):
|
|||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
self.reference_free = False
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
|
@ -43,15 +45,15 @@ class CustomKTOTrainer(KTOTrainer):
|
|||
self._precomputed_train_ref_log_probs = False
|
||||
self._precomputed_eval_ref_log_probs = False
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
self.ref_model = ref_model
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# KTO parameter
|
||||
# kto hyperparams
|
||||
self.beta = finetuning_args.kto_beta
|
||||
self.desirable_weight = finetuning_args.kto_chosen_weight
|
||||
self.undesirable_weight = finetuning_args.kto_rejected_weight
|
||||
self.ftx_gamma = finetuning_args.kto_ftx
|
||||
self.desirable_weight = finetuning_args.kto_desirable_weight
|
||||
self.undesirable_weight = finetuning_args.kto_undesirable_weight
|
||||
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
|
@ -82,78 +84,85 @@ class CustomKTOTrainer(KTOTrainer):
|
|||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
if self.processor is not None:
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
|
||||
return -all_logps.nanmean()
|
||||
|
||||
return -all_logps
|
||||
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
with torch.no_grad():
|
||||
KL_logits = model(
|
||||
batch["KL_completion_input_ids"],
|
||||
attention_mask=batch["KL_completion_attention_mask"],
|
||||
).logits
|
||||
kl_logits = model(
|
||||
input_ids=batch["kl_input_ids"],
|
||||
attention_mask=batch["kl_attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
|
||||
completion_logits = model(
|
||||
batch["input_ids"],
|
||||
target_logits = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
).logits
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
|
||||
completion_logps = self.get_batch_logps(
|
||||
completion_logits,
|
||||
batch["labels"],
|
||||
target_logps = self.get_batch_logps(
|
||||
logits=target_logits,
|
||||
labels=batch["labels"],
|
||||
average_log_prob=False,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
KL_logps = self.get_batch_logps(
|
||||
KL_logits,
|
||||
batch["kl_labels"],
|
||||
kl_logps = self.get_batch_logps(
|
||||
logits=kl_logits,
|
||||
labels=batch["kl_labels"],
|
||||
average_log_prob=False,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
if completion_logps.shape[0] != len(batch["tag"]):
|
||||
raise ValueError(
|
||||
"There is a mismatch between the number of examples in this batch and the number of "
|
||||
"examples for which an output sequence was predicted."
|
||||
)
|
||||
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
|
||||
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]
|
||||
if len(target_logps) != len(batch["kto_tags"]):
|
||||
raise ValueError("Mismatched shape of inputs and labels.")
|
||||
|
||||
chosen_logps = completion_logps[chosen_idx, ...]
|
||||
rejected_logps = completion_logps[rejected_idx, ...]
|
||||
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
|
||||
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
|
||||
|
||||
chosen_logits = completion_logits[chosen_idx, ...]
|
||||
rejected_logits = completion_logits[rejected_idx, ...]
|
||||
chosen_logps = target_logps[chosen_idx, ...]
|
||||
rejected_logps = target_logps[rejected_idx, ...]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
||||
chosen_logits = target_logits[chosen_idx, ...]
|
||||
rejected_logits = target_logits[rejected_idx, ...]
|
||||
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
):
|
||||
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, "torch.Tensor"],
|
||||
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
metrics = {}
|
||||
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_KL_logps,
|
||||
_,
|
||||
policy_kl_logps,
|
||||
) = self.forward(model, batch)
|
||||
|
||||
with torch.no_grad():
|
||||
|
@ -163,27 +172,29 @@ class CustomKTOTrainer(KTOTrainer):
|
|||
else:
|
||||
ref_model = self.ref_model
|
||||
ref_context = nullcontext()
|
||||
|
||||
with ref_context:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
reference_KL_logps,
|
||||
reference_kl_logps,
|
||||
) = self.forward(ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_KL_logps,
|
||||
policy_kl_logps,
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
reference_KL_logps,
|
||||
reference_kl_logps,
|
||||
)
|
||||
losses = losses.nanmean()
|
||||
if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])
|
||||
|
||||
if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
|
||||
sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
|
||||
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
|
||||
|
||||
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
||||
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
||||
|
|
|
@ -48,9 +48,9 @@ def run_kto(
|
|||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ def run_ppo(
|
|||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
|
|
|
@ -9,12 +9,13 @@ from ..extras.logging import get_logger
|
|||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .dpo import run_dpo
|
||||
from .kto import run_kto
|
||||
from .orpo import run_orpo
|
||||
from .ppo import run_ppo
|
||||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
from .kto import run_kto
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
@ -37,10 +38,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
|||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "orpo":
|
||||
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "kto":
|
||||
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "orpo":
|
||||
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue