Initial commit

hiyouga 2023-05-28 18:09:04 +08:00
commit 769c6ab56b
31 changed files with 1606994 additions and 0 deletions

2
.gitattributes vendored Normal file

@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

29
README.md Normal file

@@ -0,0 +1,29 @@
# LLaMA Efficient Tuning
1. Download the weights of the LLaMA models.
2. Convert them to the Hugging Face format with [this script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py):
```bash
python convert_llama_weights_to_hf.py \
--input_dir path_to_llama_weights --model_size 7B --output_dir llama_7b
```
3. Fine-tune the LLaMA models.
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--model_name_or_path llama_7b \
--do_train \
--dataset alpaca_gpt4_zh \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 2 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 100 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--fp16
```
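4. Chat with the fine-tuned model. A possible invocation, following the usage note in `src/cli_demo.py` (`--model_name_or_path` is also required by the argument parser, so it is included here as well):
```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
--model_name_or_path llama_7b \
--checkpoint_dir path_to_sft_checkpoint
```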

53
data/README.md Normal file

@@ -0,0 +1,53 @@
Data format in `dataset_info.json`:
```json
"dataset_name": {
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
}
}
```
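For example, a local file-based dataset could be registered as follows (the file name and SHA-1 value below are placeholders):
```json
"my_dataset": {
"file_name": "my_dataset.json",
"file_sha1": "0000000000000000000000000000000000000000",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
```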
Brief descriptions of some built-in datasets:
| Dataset | Size | Description |
| --- | --- | --- |
| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | The Alpaca dataset released by Stanford, used to train early LLaMA-based models such as Alpaca |
| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | The Alpaca dataset translated into Chinese with ChatGPT |
| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | Self-instruct datasets generated with GPT-4 |
| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | About 2 million Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | About 1 million Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | About 500 thousand Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | About 400 thousand personalized role-play dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, with character descriptions |
| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | About 250 thousand Chinese math problems generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, with solution steps |
| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | About 800 thousand multi-turn dialogues between users and an assistant, generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | Multilingual data covering Japanese, Simplified and Traditional Chinese, English, and more, originally used to train the Guanaco model |
| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | Chinese dataset of the Firefly (流萤) Chinese conversational model, covering multiple NLP tasks |
| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | English code-generation dataset |
| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | A collection of instruction datasets for fine-tuning |
| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | Chinese question-answering dataset collected from Baidu Zhidao |
| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | Large-scale multi-turn dialogue dataset released by THUNLP |
The BELLE datasets are produced by ChatGPT, so their accuracy is not guaranteed; the same caveat applies to every self-instruct dataset generated by GPT-like models.

260012
data/alpaca_data_en_52k.json Normal file

File diff suppressed because it is too large

257308
data/alpaca_data_zh_51k.json Normal file

File diff suppressed because it is too large

260012
data/alpaca_gpt4_data_en.json Normal file

File diff suppressed because it is too large

244092
data/alpaca_gpt4_data_zh.json Normal file

File diff suppressed because it is too large

data/comparison_gpt4_data_en.json Normal file

File diff suppressed because it is too large

data/comparison_gpt4_data_zh.json Normal file

File diff suppressed because one or more lines are too long

97
data/dataset_info.json Normal file

@@ -0,0 +1,97 @@
{
"alpaca_en": {
"hf_hub_url": "tatsu-lab/alpaca"
},
"alpaca_zh": {
"file_name": "alpaca_data_zh_51k.json",
"file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
},
"alpaca_gpt4_en": {
"file_name": "alpaca_gpt4_data_en.json",
"file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
},
"alpaca_gpt4_zh": {
"file_name": "alpaca_gpt4_data_zh.json",
"file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
},
"belle_0.5m": {
"hf_hub_url": "BelleGroup/train_0.5M_CN"
},
"belle_1m": {
"hf_hub_url": "BelleGroup/train_1M_CN"
},
"belle_2m": {
"hf_hub_url": "BelleGroup/train_2M_CN"
},
"belle_dialog": {
"hf_hub_url": "BelleGroup/generated_chat_0.4M"
},
"belle_math": {
"hf_hub_url": "BelleGroup/school_math_0.25M"
},
"belle_multiturn": {
"hf_hub_url": "BelleGroup/multiturn_chat_0.8M"
},
"guanaco": {
"hf_hub_url": "JosephusCheung/GuanacoDataset"
},
"firefly": {
"hf_hub_url": "YeungNLP/firefly-train-1.1M",
"columns": {
"prompt": "input",
"query": "",
"response": "target",
"history": ""
}
},
"codealpaca": {
"hf_hub_url": "sahil2801/CodeAlpaca-20k"
},
"alpaca_cot": {
"hf_hub_url": "QingyiSi/Alpaca-CoT"
},
"webqa": {
"hf_hub_url": "suolyer/webqa",
"columns": {
"prompt": "input",
"query": "",
"response": "output",
"history": ""
}
},
"ultra_chat": {
"script_url": "ultra_chat",
"columns": {
"prompt": "instruction",
"query": "",
"response": "output",
"history": "history"
}
},
"example": {
"script_url": "example_dataset",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json",
"file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
},
"comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json",
"file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"columns": {
"prompt": "instruction",
"query": "",
"response": "output",
"history": "history"
}
}
}

data/example_dataset/example_dataset.py Normal file

@@ -0,0 +1,46 @@
import json
import datasets
from typing import Any, Dict, List
_DESCRIPTION = "An example of dataset for LLaMA."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": file_path
}
)
]
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
with open(filepath, "r", encoding="utf-8") as f:
example_dataset = json.load(f)
for key, example in enumerate(example_dataset):
yield key, example

data/example_dataset/examples.json Normal file

@@ -0,0 +1,20 @@
[
{
"instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
"input": "",
"output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
"history": [
["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
]
},
{
"instruction": "好的,谢谢你!",
"input": "",
"output": "不客气,有其他需要帮忙的地方可以继续问我。",
"history": [
["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
["我在纽约。", "纽约今天晴间多云气温最高约26摄氏度最低约18摄氏度记得注意保暖喔。"]
]
}
]

data/hh_rlhf_en/hh_rlhf_en.py Normal file

@@ -0,0 +1,97 @@
import json
import datasets
from typing import Any, Dict, List
_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
_CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
"train": [
_URL + "harmless-base/train.jsonl.gz",
_URL + "helpful-base/train.jsonl.gz",
_URL + "helpful-online/train.jsonl.gz",
_URL + "helpful-rejection-sampled/train.jsonl.gz"
],
"test": [
_URL + "harmless-base/test.jsonl.gz",
_URL + "helpful-base/test.jsonl.gz",
_URL + "helpful-online/test.jsonl.gz",
_URL + "helpful-rejection-sampled/test.jsonl.gz"
]
}
class HhRlhfEn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Sequence(datasets.Value("string")),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download_and_extract(_URLS)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_path["train"]
}
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepaths": file_path["test"]
}
)
]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat data for LLaMA
key = 0
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
for row in f:
data = json.loads(row)
chosen = data["chosen"]
rejected = data["rejected"]
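# the offsets 13 and 9 below equal len("\n\nAssistant: ") and len("\n\nHuman: ")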
assist_idx = rejected.rfind("\n\nAssistant: ")
r_reject = rejected[assist_idx+13:].strip()
assist_idx = chosen.rfind("\n\nAssistant: ")
r_accept = chosen[assist_idx+13:].strip()
human_idx = chosen.rfind("\n\nHuman: ")
query = chosen[human_idx+9:assist_idx].strip()
prompt = chosen[:human_idx]
history = []
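# walk backwards through the remaining prompt, peeling off one (query, response) turn per iteration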
while prompt.rfind("\n\nAssistant: ") != -1:
assist_idx = prompt.rfind("\n\nAssistant: ")
human_idx = prompt.rfind("\n\nHuman: ")
if human_idx != -1:
old_query = prompt[human_idx+9:assist_idx].strip()
old_resp = prompt[assist_idx+13:].strip()
history.insert(0, (old_query, old_resp))
else:
break
prompt = prompt[:human_idx]
yield key, {
"instruction": query,
"output": [r_accept, r_reject],
"history": history
}
key += 1

data/ultra_chat/ultra_chat.py Normal file

@@ -0,0 +1,76 @@
import json
import datasets
from typing import Any, Dict, List
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
_CITATION = """\
@misc{UltraChat,
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\\url{https://github.com/thunlp/ultrachat}},
}
"""
_HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepaths": file_paths
}
)
]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat data for LLaMA
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
for row in f:
try:
data = json.loads(row)
except json.JSONDecodeError: # skip malformed lines
continue
key = data["id"]
content = data["data"]
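# data["data"] is a flat list of alternating user/assistant utterances;
# drop a trailing user turn so the list splits into complete (query, response) pairs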
if len(content) % 2 == 1:
content.pop(-1)
if len(content) < 2:
continue
query = content[-2]
response = content[-1]
history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
yield key, {
"instruction": query,
"output": response,
"history": history
}

0
src/__init__.py Normal file

66
src/cli_demo.py Normal file

@@ -0,0 +1,66 @@
# coding=utf-8
# Implements a command-line chat demo for LLaMA fine-tuned with PEFT.
# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
import torch
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
def main():
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
def predict(query, history: list):
inputs = tokenizer([query], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = {
"do_sample": True,
"top_p": 0.9,
"top_k": 40,
"temperature": 0.7,
"num_beams": 1,
"max_new_tokens": 256,
"repetition_penalty": 1.5
}
with torch.no_grad():
generation_output = model.generate(**inputs, **gen_kwargs)
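# strip the prompt tokens so that only the newly generated response is decoded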
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
history = history + [(query, response)]
return response, history
history = []
print("欢迎使用 LLaMA-7B 模型输入内容即可对话clear清空对话历史stop终止程序")
while True:
try:
query = input("\nInput: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
continue
response, history = predict(query, history)
print("LLaMA-7B:", response)
if __name__ == "__main__":
main()

23
src/export_model.py Normal file

@@ -0,0 +1,23 @@
# coding=utf-8
# Exports the fine-tuned LLaMA model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
from transformers import HfArgumentParser, TrainingArguments
from utils import ModelArguments, load_pretrained
def main():
parser = HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
tokenizer.save_pretrained(training_args.output_dir)
print("model and tokenizer have been saved at:", training_args.output_dir)
if __name__ == "__main__":
main()

80
src/train_ppo.py Normal file

@@ -0,0 +1,80 @@
# coding=utf-8
# Implements parameter-efficient PPO training of fine-tuned LLaMA.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
import math
from torch.optim import AdamW
from transformers.optimization import get_scheduler
from trl import PPOConfig
from utils import (
prepare_args,
prepare_data,
load_pretrained,
preprocess_data,
DataCollatorForLLaMA,
PPOTrainerForLLaMA,
plot_loss
)
def main():
# prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model)
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
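# effective batch size across devices and gradient accumulation steps, used below to estimate the number of training steps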
total_train_batch_size = \
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
)
# Initialize our Trainer
ppo_trainer = PPOTrainerForLLaMA(
training_args=training_args,
finetuning_args=finetuning_args,
config=ppo_config,
model=model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
ppo_trainer.save_model()
ppo_trainer.save_state() # must be after save_model
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "reward"])
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

72
src/train_rm.py Normal file

@@ -0,0 +1,72 @@
# coding=utf-8
# Implements parameter-efficient training of a reward model based on LLaMA.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from utils import (
prepare_args,
prepare_data,
load_pretrained,
preprocess_data,
PairwiseDataCollatorForLLaMA,
PairwiseTrainerForLLaMA,
plot_loss
)
def main():
# prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
training_args.remove_unused_columns = False # Important for pairwise dataset
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = PairwiseTrainerForLLaMA(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
**trainer_kwargs
)
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

95
src/train_sft.py Normal file

@@ -0,0 +1,95 @@
# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning methods for LLaMA.
# This code is inspired by:
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
from utils import (
load_pretrained,
prepare_args,
prepare_data,
preprocess_data,
DataCollatorForLLaMA,
Seq2SeqTrainerForLLaMA,
ComputeMetrics,
get_logits_processor,
plot_loss
)
def main():
# Prepare pretrained model and dataset
model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
training_args.generation_num_beams = data_args.num_beams if \
data_args.num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = Seq2SeqTrainerForLLaMA(
finetuning_args=finetuning_args,
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs
)
# Keyword arguments for `model.generate`
gen_kwargs = {
"do_sample": True,
"top_p": 0.7,
"max_length": data_args.max_source_length + data_args.max_target_length + 1,
"temperature": 0.95,
"logits_processor": get_logits_processor()
}
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results, tokenizer)
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

15
src/utils/__init__.py Normal file

@@ -0,0 +1,15 @@
from .common import (
load_pretrained,
prepare_args,
prepare_data,
preprocess_data
)
from .data_collator import DataCollatorForLLaMA
from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
from .ppo import PPOTrainerForLLaMA
from .config import ModelArguments
from .other import auto_configure_device_map, get_logits_processor, plot_loss

459
src/utils/common.py Normal file

@@ -0,0 +1,459 @@
import os
import sys
import torch
import hashlib
from typing import List, Literal, Optional, Tuple
import transformers
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
HfArgumentParser,
Seq2SeqTrainingArguments
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
import datasets
from datasets import Dataset, concatenate_datasets, load_dataset
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from trl import AutoModelForCausalLMWithValueHead
from .config import (
ModelArguments,
DataTrainingArguments,
FinetuningArguments
)
from .other import (
get_logger,
load_trainable_params,
load_valuehead_params,
print_trainable_params,
prepare_model_for_training,
IGNORE_INDEX,
FINETUNING_ARGS_NAME
)
check_min_version("4.29.1")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
logger = get_logger(__name__)
def init_adapter(
model: PreTrainedModel,
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
is_trainable: bool
) -> PreTrainedModel:
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if finetuning_args.finetuning_type == "full":
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze":
logger.info("Fine-tuning method: Freeze")
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
param.requires_grad_(False)
else:
param.data = param.data.to(torch.float32)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
latest_checkpoint = None
if model_args.checkpoint_dir is not None:
if is_trainable and finetuning_args.resume_lora_training: # continually train on the lora weights
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
for checkpoint in checkpoints_to_merge:
model = PeftModel.from_pretrained(model, checkpoint)
model = model.merge_and_unload()
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if latest_checkpoint is not None: # resume lora training
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)
if is_trainable and latest_checkpoint is None: # create new lora weights while training
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target
)
model = get_peft_model(model, lora_config)
return model
def load_pretrained(
model_args: ModelArguments,
finetuning_args: Optional[FinetuningArguments] = None,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
r"""
Loads pretrained model and tokenizer.
Support both training and inference.
"""
if (not is_trainable) and (model_args.checkpoint_dir is None):
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
for checkpoint_dir in model_args.checkpoint_dir:
if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
logger.warning("Only LoRA tuning accepts multiple checkpoints.")
assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
tokenizer = LlamaTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
padding_side="left"
)
tokenizer.pad_token_id = 0 # set as the <unk> token
# Quantization configurations (using bitsandbytes library).
config_kwargs = {}
if model_args.quantization_bit is not None:
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."
config_kwargs["load_in_8bit"] = True
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pretrained models (without valuehead).
model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
model = prepare_model_for_training(model) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.half() # cast all params to float16 for inference
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
load_valuehead_params(model, model_args.reward_model)
# Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
# To meet the compliance requirements of the transformers library
if model_args.quantization_bit is not None:
model._is_int8_training_enabled = True
print_trainable_params(model)
return model, tokenizer
def prepare_args(
stage: Literal["sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True in RM and PPO stages.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training for LLaMA.")
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
training_args.ddp_find_unused_parameters = False
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args
def prepare_data(
model_args: ModelArguments,
data_args: DataTrainingArguments
) -> Dataset:
def checksum(file_path, expected_sha1): # avoid shadowing the built-in `hash`
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != expected_sha1:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
elif dataset_attr.load_from == "script":
raw_datasets = load_dataset(
os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
cache_dir=model_args.cache_dir
)
elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) # support json, jsonl and csv
extension = dataset_attr.file_name.split(".")[-1]
if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raw_datasets = load_dataset(
extension,
data_files=data_file,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
)
else:
raise NotImplementedError
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
]: # align every dataset to the same four columns
if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
else:
all_datasets = concatenate_datasets(all_datasets)
return all_datasets
def preprocess_data(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
data_args: DataTrainingArguments,
training_args: Seq2SeqTrainingArguments,
stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Dataset:
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
def format_example(examples): # support question with a single answer or multiple answers
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += examples["query"][i]
prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\n" + prefix
if examples["history"][i]:
history = examples["history"][i]
for old_query, response in history:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant: ".format(query)
yield prompt, answer
def preprocess_supervised_dataset(examples):
# build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_evaluation_dataset(examples):
# build inputs with format `X <s>` and labels with format `Y <s>`
model_inputs = {"input_ids": [], "labels": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # bos token
target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id]
labels = target_ids + [tokenizer.bos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_pairwise_dataset(examples):
# build input pairs with format `X <s> Y1 </s>` and `X <s> Y2 </s>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for prompt, answer in format_example(examples):
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1]
if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
model_inputs["accept_ids"].append(accept_ids)
model_inputs["reject_ids"].append(reject_ids)
return model_inputs
def print_sft_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
def print_ppo_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
if stage == "sft":
if (not training_args.do_train) and training_args.predict_with_generate: # with generation
preprocess_function = preprocess_evaluation_dataset
else: # without generation
preprocess_function = preprocess_supervised_dataset
elif stage == "rm":
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
preprocess_function = preprocess_evaluation_dataset
with training_args.main_process_first(desc="dataset map pre-processing"):
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
if stage == "sft":
print_sft_dataset_example(dataset[0])
elif stage == "rm":
print_pairwise_dataset_example(dataset[0])
elif stage == "ppo":
print_ppo_dataset_example(dataset[0])
return dataset

212
src/utils/config.py Normal file

@@ -0,0 +1,212 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
file_sha1: Optional[str] = None
def __post_init__(self):
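# the default column mapping follows the Alpaca format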
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
)
use_fast_tokenizer: Optional[bool] = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
)
use_auth_token: Optional[bool] = field(
default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
def __post_init__(self):
if self.checkpoint_dir is not None: # support merging lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
dataset: Optional[str] = field(
default="alpaca_zh",
metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "The name of the folder containing datasets."}
)
split: Optional[str] = field(
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."}
)
max_source_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total input sequence length after tokenization."}
)
max_target_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total output sequence length after tokenization."}
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
)
num_beams: Optional[int] = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
)
ignore_pad_token_for_loss: Optional[bool] = field(
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
source_prefix: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
dev_ratio: Optional[float] = field(
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
)
def __post_init__(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
dataset_info = json.load(open(os.path.join(self.dataset_dir, "dataset_info.json"), "r"))
self.dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
)
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
@dataclass
class FinetuningArguments:
"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
)
lora_dropout: Optional[float] = field(
default=0.1,
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
)
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if isinstance(self.lora_target, str):
self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)] # hard-coded for a 28-layer backbone; adjust for models with a different depth
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
if self.name_module_trainable == "mlp":
self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
elif self.name_module_trainable == "qkv":
self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids]
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Create an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
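# A quick sanity check of the freeze-tuning layer selection above, assuming the
# hard-coded 28-layer indexing (a sketch, not part of the training pipeline):
#   args = FinetuningArguments(finetuning_type="freeze", num_layer_trainable=2)
#   args.trainable_layers  # -> ['layers.27.mlp', 'layers.26.mlp']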

67
src/utils/data_collator.py Normal file
View File

@ -0,0 +1,67 @@
import torch
from typing import Dict, Optional, Sequence, Union
from transformers import DataCollatorWithPadding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from .other import IGNORE_INDEX
class DataCollatorForLLaMA(DataCollatorWithPadding):
r"""
Data collator for LLaMA. It is capable of dynamically padding batched data.
"""
def __init__(
self,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
ignore_pad_token_for_loss: Optional[bool] = False
):
super().__init__(tokenizer, padding=True)
self.model = model
self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
r"""
Generates attention masks for left-padded sequences.
"""
batch_size, seq_length = input_ids.size()
attention_mask = torch.ones((batch_size, seq_length), device=device)
for i, seq in enumerate(input_ids):
attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
attention_mask = attention_mask.bool()
return attention_mask
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We adopt left-padding in both training and evaluation.
"""
if isinstance(features[0]["input_ids"], torch.Tensor):
input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
else:
input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
if "labels" in features[0]:
if isinstance(features[0]["labels"], torch.Tensor):
labels = [feature["labels"].clone().detach().flip(0) for feature in features]
else:
labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
input_ids = input_ids + labels # pad them to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)
batch = {}
if "labels" in features[0]:
input_ids, labels = input_ids.split(len(features), dim=0)
labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
batch["labels"] = labels
batch["input_ids"] = input_ids
batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
return batch
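# An isolated sketch of the left-padding trick in __call__ above: flip each
# sequence, right-pad, then flip back (pad_token_id=0 is assumed for brevity):
#   seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
#   padded = torch.nn.utils.rnn.pad_sequence(
#       [s.flip(0) for s in seqs], batch_first=True, padding_value=0
#   ).flip(-1)
#   # -> tensor([[5, 6, 7], [0, 8, 9]])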

205
src/utils/other.py Normal file
View File

@ -0,0 +1,205 @@
import os
import sys
import json
import torch
import logging
from typing import Dict, List, Optional
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from peft.utils.other import WEIGHTS_NAME
IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[logging.StreamHandler(sys.stdout)]
)
def get_logger(name: str) -> logging.Logger:
return logging.getLogger(name)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
# Avoid runtime error in model.generate(do_sample=True).
# Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
return logits_processor
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: PreTrainedModel,
output_embedding_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting
) -> PreTrainedModel:
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
if use_gradient_checkpointing:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if hasattr(model, output_embedding_layer_name):
output_embedding_layer = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x):
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
return model
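# A hedged usage sketch of prepare_model_for_training; the model path is an
# assumption and the actual loading happens elsewhere in this repository:
#   model = AutoModelForCausalLM.from_pretrained("path/to/llama", torch_dtype=torch.float16)
#   model = prepare_model_for_training(model)  # fp32 norms + gradient checkpointing + fp32 lm_head output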
def print_trainable_params(model: torch.nn.Module) -> None:
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights."
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
r"""
Configures device map for LLaMA.
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
"""
num_layers = 28 # layer count follows the ChatGLM source this helper was borrowed from; LLaMA-7B actually has 32 layers
layers_per_gpu = 30 / num_gpus # 28 layers plus 2 slots for the shared modules placed on GPU 0
device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
added_layers = 2
target_gpu = 0
for i in range(num_layers):
if added_layers >= layers_per_gpu:
target_gpu += 1
added_layers = 0
assert target_gpu < num_gpus
device_map[f"model.layers.{i}"] = target_gpu
added_layers += 1
return device_map
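# Sketch of the resulting split for num_gpus=2 (28 layers plus 2 slots for the
# shared modules, i.e. 30 / 2 = 15 slots per GPU):
#   GPU 0: model.embed_tokens, model.norm, lm_head and model.layers.0-12
#   GPU 1: model.layers.13-27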
def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
"""
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
last = smoothed_val
return smoothed
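# A small worked example of the EMA above with weight=0.5 (illustrative values):
#   smooth([0.0, 1.0, 1.0], weight=0.5)  # -> [0.0, 0.5, 0.75]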
def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
import matplotlib.pyplot as plt
data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
for key in keys:
steps, metrics = [], []
for i in range(len(data["log_history"])):
if key in data["log_history"][i]:
steps.append(data["log_history"][i]["step"])
metrics.append(data["log_history"][i][key])
if len(metrics) == 0:
logger.warning(f"No metric {key} to plot.")
continue
plt.figure()
plt.plot(steps, metrics, alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), label="smoothed")
plt.title("training {} of {}".format(key, training_args.output_dir))
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))

51
src/utils/pairwise.py Normal file
View File

@ -0,0 +1,51 @@
import torch
from typing import Dict, Sequence, Union
from .data_collator import DataCollatorForLLaMA
from .peft_trainer import PeftTrainer
from .other import get_logger
logger = get_logger(__name__)
class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
return super().__call__(features)
class PairwiseTrainerForLLaMA(PeftTrainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.can_return_loss = True # override property to return eval_loss
def compute_loss(self, model, inputs, return_outputs=False):
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
We use the score on the EOS token to represent the reward of the whole sentence.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
outputs = {"r_accept": r_accept, "r_reject": r_reject}
return (loss, outputs) if return_outputs else loss
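# A minimal numeric check of the pairwise loss above with toy rewards
# (r_accept=2.0, r_reject=0.0 are assumptions for illustration):
#   -torch.log(torch.sigmoid(torch.tensor(2.0) - torch.tensor(0.0)))  # -> ~0.1269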

78
src/utils/peft_trainer.py Normal file
View File

@ -0,0 +1,78 @@
import os
import torch
from typing import Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from peft.utils.other import WEIGHTS_NAME
from .config import FinetuningArguments
from .other import (
get_logger,
get_state_dict,
load_trainable_params,
load_valuehead_params,
FINETUNING_ARGS_NAME,
VALUE_HEAD_FILE_NAME
)
logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
Saves trainable parameters as model checkpoint.
This function will only be executed at the process zero.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if hasattr(model, "pretrained_model"): # for models with valuehead
backbone_model = getattr(model, "pretrained_model")
else:
backbone_model = model
if hasattr(backbone_model, "peft_config"): # peft methods
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
else:
torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
r"""
Loads trainable parameters from model checkpoint.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
if hasattr(model, "peft_config"): # peft methods
model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
else:
load_trainable_params(model, self.state.best_model_checkpoint)
if hasattr(model, "v_head"):
load_valuehead_params(model, self.state.best_model_checkpoint)
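# For orientation, a checkpoint directory written by _save above typically
# contains (file names come from the constants imported in this module; the
# directory name is illustrative):
#   checkpoint-1000/
#     adapter_model.bin      # peft WEIGHTS_NAME, when a LoRA adapter is attached
#     value_head.bin         # VALUE_HEAD_FILE_NAME, only for models with a value head
#     training_args.bin      # TRAINING_ARGS_NAME
#     finetuning_args.json   # FINETUNING_ARGS_NAME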

241
src/utils/ppo.py Normal file
View File

@ -0,0 +1,241 @@
import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Tuple
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TrainerState
from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
from .peft_trainer import PeftTrainer
from .config import FinetuningArguments
from .other import (
AverageMeter,
get_logger,
get_logits_processor
)
logger = get_logger(__name__)
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
if target == "reward": # save original head temporarily
valuehead_state_dict = model.v_head.state_dict()
setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"])
setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"])
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict({
"summary.weight": getattr(model, "{}_head_weight".format(target)),
"summary.bias": getattr(model, "{}_head_bias".format(target))
})
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
self.state = TrainerState()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
def ppo_train(self, max_target_length: int) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
len_dataloader = len(self.dataloader)
num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
# Keyword arguments for `model.generate`
gen_kwargs = {
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
reward_meter = AverageMeter()
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
for _ in range(self.config.gradient_accumulation_steps):
batch = next(dataiter)
steps_trained += 1
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get response from LLaMA
query_tensors: torch.Tensor = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
queries: List[torch.Tensor] = []
responses: List[torch.Tensor] = []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
if response_length < 2: # make response have at least 2 tokens
responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
else:
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1]]
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
# Run PPO step
unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
loss_meter.update(stats["ppo/loss/total"])
reward_meter.update(torch.tensor(rewards).mean().item(), n=len(rewards)) # record the per-example mean reward so the running average stays per-example
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = {
"loss": round(loss_meter.avg, 4),
"reward": round(reward_meter.avg, 4),
"learning_rate": stats["ppo/learning_rate"],
"epoch": round(step / num_steps_per_epoch, 2)
}
print(logs)
logs["step"] = step
self.state.log_history.append(logs)
loss_meter.reset()
reward_meter.reset()
if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
@torch.no_grad()
def generate(
self,
inputs: Dict[str, torch.Tensor],
length_sampler: Callable = None,
return_prompt: bool = True,
**generation_kwargs,
) -> torch.Tensor:
r"""
Generates model's responses given queries.
Subclass and override to inject custom behavior.
"""
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
unwrapped_model = self.accelerator.unwrap_model(self.model)
response = unwrapped_model.generate(**inputs, **generation_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
if unwrapped_model.pretrained_model.generation_config._from_model_config:
unwrapped_model.pretrained_model.generation_config._from_model_config = False
if not return_prompt and not self.is_encoder_decoder:
return response[:, inputs["input_ids"].size(1):]
return response
def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_cuda_cache()
def batched_forward_pass(
self,
model: AutoModelForCausalLMWithValueHead,
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
):
r"""
Calculates model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
bs = len(model_inputs["input_ids"])
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
for i in range(int(bs / fbs)):
input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()}
input_ids: torch.Tensor = input_kwargs["input_ids"] # left-padded sequences
logits, _, values = model(**input_kwargs)
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(input_ids)
for j in range(fbs):
start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item()
masks[j][start:] = 1
if len(masks[j][start:]) < 2:
raise ValueError("Responses are too short. Make sure they are at least 2 tokens long.")
all_logits.append(logits)
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1],
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
Subclass and override to inject custom behavior.
"""
if self.args.should_save:
self._save(output_dir)
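# A toy illustration of the mask construction in batched_forward_pass above
# (bos_token_id=1 and left padding with 0 are assumptions for the example):
#   input_ids = torch.tensor([[0, 0, 1, 5, 6]])
#   start = (input_ids[0] == 1).nonzero()[0].item()  # -> 2
#   masks = torch.zeros_like(input_ids); masks[0][start:] = 1
#   # -> tensor([[0, 0, 1, 1, 1]])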

96
src/utils/seq2seq.py Normal file
View File

@ -0,0 +1,96 @@
import os
import json
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.tokenization_utils import PreTrainedTokenizer
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from .peft_trainer import PeftTrainer
from .other import get_logger, IGNORE_INDEX
logger = get_logger(__name__)
@dataclass
class ComputeMetrics:
r"""
Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA.
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
"""
tokenizer: PreTrainedTokenizer
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
# Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
for pred, label in zip(preds, labels):
pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
pred_text = self.tokenizer.decode(pred, skip_special_tokens=True)
label_text = self.tokenizer.decode(label, skip_special_tokens=True)
hypothesis = list(jieba.cut(pred_text))
reference = list(jieba.cut(label_text))
if len(" ".join(hypothesis).split()) == 0:
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
else:
rouge = Rouge()
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
result = scores[0]
for k, v in result.items():
score_dict[k].append(round(v["f"] * 100, 4))
bleu_score = sentence_bleu([list(label_text)], list(pred_text), smoothing_function=SmoothingFunction().method3) # character-level BLEU over the decoded strings, not the raw token ids
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
return {k: float(np.mean(v)) for k, v in score_dict.items()}
class Seq2SeqTrainerForLLaMA(PeftTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def save_predictions(
self,
predict_results: PredictionOutput,
tokenizer: PreTrainedTokenizer
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for pred, label in zip(preds, labels):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))

129
src/web_demo.py Normal file
View File

@ -0,0 +1,129 @@
# coding=utf-8
# Implements user interface in browser for LLaMA fine-tuned with PEFT.
# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
import torch
import mdtex2html
import gradio as gr
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
device_map = auto_configure_device_map(torch.cuda.device_count())
model = dispatch_model(model, device_map)
else:
model = model.cuda()
model.eval()
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert(message),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text): # copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>"+line
text = "".join(lines)
return text
def predict(input, chatbot, max_length, top_p, temperature, history):
chatbot.append((parse_text(input), ""))
inputs = tokenizer([input], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = {
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"num_beams": 1,
"max_length": max_length,
"repetition_penalty": 1.0
}
with torch.no_grad():
generation_output = model.generate(**inputs, **gen_kwargs)
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
history = history + [(input, response)]
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], []
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
history = gr.State([])
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)