support streaming data, fix #284 #274 #268

This commit is contained in:
hiyouga 2023-07-31 23:33:00 +08:00
parent 513e1f1ec9
commit 0411a4b3e1
28 changed files with 478 additions and 344 deletions

View File

@ -12,15 +12,15 @@
## Changelog
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
@ -175,6 +176,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
@ -197,6 +199,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
@ -220,6 +223,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
@ -278,6 +282,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
@ -296,6 +301,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
@ -311,6 +317,7 @@ If you want to predict the samples with empty responses, please kindly fill the
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@ -322,6 +329,7 @@ Visit `http://localhost:8000/docs` for API documentation.
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@ -331,6 +339,7 @@ python src/cli_demo.py \
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@ -340,6 +349,7 @@ python src/web_demo.py \
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export

View File

@ -12,15 +12,15 @@
## 更新日志
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--prompt_template llama2` 参数。
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model``--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--prompt_template baichuan` 参数。
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model``--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--prompt_template intern` 参数。
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b``--lora_target query_key_value` 参数。
@ -153,6 +153,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
@ -174,7 +175,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
@ -196,7 +198,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
@ -219,7 +222,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
@ -277,7 +281,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_en \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
@ -295,7 +300,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_en \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
@ -311,6 +317,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@ -322,6 +329,7 @@ python src/api_demo.py \
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@ -331,6 +339,7 @@ python src/cli_demo.py \
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
@ -340,6 +349,7 @@ python src/web_demo.py \
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export

View File

@ -1,42 +1,50 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import get_template
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner import load_model_and_tokenizer
if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
class ChatModel:
def __init__(
self,
model_args: ModelArguments,
data_args: DataArguments,
finetuning_args: FinetuningArguments,
generating_args: GeneratingArguments
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments"
) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(self.model)
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
self.model = dispatch_model(self.model, device_map)
else:
self.model = self.model.cuda()
self.template = get_template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix or ""
self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix
self.generating_args = generating_args
def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix or self.source_prefix
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
inputs = self.tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
@ -71,7 +79,11 @@ class ChatModel:
@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
@ -82,7 +94,11 @@ class ChatModel:
@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

View File

@ -1,40 +1,50 @@
import os
import hashlib
from typing import List
from typing import TYPE_CHECKING, List, Optional
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, DataArguments
if TYPE_CHECKING:
from datasets import Dataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def get_dataset(
model_args: ModelArguments,
data_args: DataArguments
) -> Dataset:
def checksum(file_path, hash):
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
ext2type = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
model_args: "ModelArguments",
data_args: "DataArguments"
) -> "Dataset":
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
all_datasets: List["Dataset"] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
@ -47,60 +57,56 @@ def get_dataset(
data_path = None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = ext2type.get(data_files[0].split(".")[-1], None)
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = ext2type.get(data_files[0].split(".")[-1], None)
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
checksum(data_files[0], dataset_attr.dataset_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
raw_datasets = load_dataset(
dataset = load_dataset(
data_path,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
)
dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
prefix_data = [dataset_attr.source_prefix] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
("response_column", "response"),
("history_column", "history")
]: # every dataset will have 4 columns same as each other
if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
dataset = dataset.add_column("prefix", prefix_data)
for column_name in ["prompt", "query", "response", "history"]: # align datasets
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.source_prefix: # add prefix
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
all_datasets = all_datasets[0]
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
else:
all_datasets = concatenate_datasets(all_datasets)
return all_datasets
raise ValueError("Unknown mixing strategy.")

View File

@ -1,65 +1,63 @@
from typing import Literal
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
from itertools import chain
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from datasets import Dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template
from llmtuner.hparams import DataArguments
if TYPE_CHECKING:
from datasets import Dataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
def preprocess_dataset(
dataset: Dataset,
tokenizer: PreTrainedTokenizer,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
dataset: "Dataset",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:
) -> "Dataset":
column_names = list(dataset.column_names or [])
template = get_template(data_args.template)
column_names = list(dataset.column_names)
prompt_template = get_template(data_args.prompt_template)
# support question with a single answer or multiple answers
def get_dialog(examples):
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
yield dialog
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = history if "history" in examples and examples["history"][i] else []
prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
yield query, response, history, prefix
def preprocess_pretrain_dataset(examples):
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
concatenated_ids = list(chain(*text_ids))
total_length = len(concatenated_ids)
block_size = data_args.max_source_length - 1
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.max_source_length
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of max_source_length
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
for i in range(0, total_length, block_size)]
return {
"input_ids": result,
"labels": result.copy()
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
def preprocess_supervised_dataset(examples):
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
model_inputs = {"input_ids": [], "labels": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
for dialog in get_dialog(examples):
for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], []
for i in range(len(dialog) // 2):
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@ -73,19 +71,20 @@ def preprocess_dataset(
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_unsupervised_dataset(examples):
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `<bos> Y`
model_inputs = {"input_ids": [], "labels": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
for query, response, history, prefix in construct_example(examples):
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
target_ids = tokenizer.encode(text=response, add_special_tokens=True)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@ -93,6 +92,7 @@ def preprocess_dataset(
target_ids = target_ids[:data_args.max_target_length]
model_inputs["input_ids"].append(source_ids)
model_inputs["attention_mask"].append([1] * len(source_ids))
model_inputs["labels"].append(target_ids)
return model_inputs
@ -100,12 +100,12 @@ def preprocess_dataset(
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
for dialog in get_dialog(examples):
prompt, answer = "".join(dialog[:-1]), dialog[-1]
for query, response, history, prefix in construct_example(examples):
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@ -141,34 +141,44 @@ def preprocess_dataset(
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
elif stage == "sft":
if not training_args.predict_with_generate:
preprocess_function = preprocess_supervised_dataset
else:
preprocess_function = preprocess_unsupervised_dataset
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
batched=True,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
**kwargs
)
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
if stage == "pt":
print_unsupervised_dataset_example(dataset[0])
print_unsupervised_dataset_example(next(iter(dataset)))
elif stage == "sft":
print_supervised_dataset_example(dataset[0])
print_supervised_dataset_example(next(iter(dataset)))
elif stage == "rm":
print_pairwise_dataset_example(dataset[0])
print_pairwise_dataset_example(next(iter(dataset)))
elif stage == "ppo":
print_unsupervised_dataset_example(dataset[0])
print_unsupervised_dataset_example(next(iter(dataset)))
return dataset

View File

@ -1,13 +1,12 @@
from typing import Dict
from datasets import Dataset
from typing import TYPE_CHECKING, Dict
if TYPE_CHECKING:
from datasets import Dataset
def split_dataset(
dataset: Dataset, dev_ratio: float, do_train: bool
) -> Dict[str, Dataset]:
# Split the dataset
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
if do_train:
if dev_ratio > 1e-6:
if dev_ratio > 1e-6: # Split the dataset
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:

View File

@ -1,16 +1,13 @@
import os
import json
import time
from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from transformers import TrainerCallback
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
class LogCallback(TrainerCallback):
@ -20,13 +17,13 @@ class LogCallback(TrainerCallback):
self.start_time = time.time()
self.tracker = {}
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
self.start_time = time.time()
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
@ -35,7 +32,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
@ -43,7 +40,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
"""

View File

@ -1,12 +1,14 @@
import torch
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
class AverageMeter:
r"""
@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList:
return logits_processor
def print_trainable_params(model: torch.nn.Module) -> None:
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_params = num_params * 2
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
return trainable_params, all_param
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: PreTrainedModel,
model: "PreTrainedModel",
finetuning_type: str,
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> PreTrainedModel:
) -> "PreTrainedModel":
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
@ -84,6 +94,9 @@ def prepare_model_for_training(
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_layer_name):
if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
@ -92,11 +105,8 @@ def prepare_model_for_training(
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
new_output_layer = CastOutputToFloat(output_layer)
# adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
# (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
setattr(new_output_layer, "weight", output_layer.weight)
setattr(model, output_layer_name, new_output_layer)
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
return model

View File

@ -1,6 +1,6 @@
import os
import torch
from typing import Dict, Optional
from typing import Dict
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
state_dict = model.state_dict()
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if (not trainable_only) or v.requires_grad:
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict

View File

@ -11,37 +11,46 @@ class Template:
use_history: bool
def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = "",
eos_token: Optional[str] = "</s>"
) -> str:
r"""
Returns a string containing prompt without response.
"""
return "".join(self._format_example(query, history, prefix))
return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
def get_dialog(
self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
r"""
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
Returns a list containing prompt-response pairs.
"""
return self._format_example(query, history, prefix) + [resp]
result = self._format_example(query, history, prefix)
result[-1][-1] = resp
return result
def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = ""
) -> List[Tuple[str, str]]:
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, "<dummy>")]
convs = []
for turn_idx, (user_query, bot_resp) in enumerate(history):
if turn_idx == 0:
convs.append(prefix + self.prompt.format(query=user_query))
convs.append(bot_resp)
else:
convs.append(self.sep + self.prompt.format(query=user_query))
convs.append(bot_resp)
return convs[:-1] # drop last
history = history + [(query, "")]
convs = [
[(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
for turn_idx, (query_i, resp_i) in enumerate(history)
]
return convs
templates: Dict[str, Template] = {}
@ -103,7 +112,7 @@ register_template(
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt=" [INST] {query} [/INST] ",
sep="</s>",
sep="",
use_history=True
)
@ -131,7 +140,7 @@ register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
sep="",
use_history=True
)
@ -216,7 +225,7 @@ register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="</s>",
sep="",
use_history=True
)

View File

@ -1,6 +1,6 @@
import os
import json
from typing import List, Optional
from typing import List, Literal, Optional
from dataclasses import dataclass, field
@ -16,10 +16,10 @@ class DatasetAttr:
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
self.query_column = "input"
self.response_column = "output"
self.history_column = None
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass
@ -27,8 +27,11 @@ class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: str = field(
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
default="alpaca_zh",
default="alpaca_en",
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
@ -39,6 +42,18 @@ class DataArguments:
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
@ -75,10 +90,6 @@ class DataArguments:
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
)
prompt_template: Optional[str] = field(
default="default",
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
@ -111,9 +122,9 @@ class DataArguments:
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
"""
Arguments pertaining to which techniques we are going to fine-tuning with.
Arguments pertaining to which stage we are going to perform.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
default="sft",

View File

@ -1,7 +1,7 @@
import os
import torch
from typing import TYPE_CHECKING
from transformers.modeling_utils import PreTrainedModel
from peft import (
PeftModel,
TaskType,
@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
model: PreTrainedModel,
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
) -> PreTrainedModel:
) -> "PreTrainedModel":
r"""
Initializes the adapters.

View File

@ -1,6 +1,6 @@
import os
import torch
from typing import Literal, Optional, Tuple
from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
AutoConfig,
@ -16,11 +16,13 @@ from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
@ -33,8 +35,8 @@ require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
def load_model_and_tokenizer(
model_args: ModelArguments,
finetuning_args: FinetuningArguments,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
@ -141,6 +143,9 @@ def load_model_and_tokenizer(
model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
print_trainable_params(model)
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
return model, tokenizer

View File

@ -19,20 +19,39 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
parser = HfArgumentParser((
GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
))
return _parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
parser = HfArgumentParser((
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
))
return _parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
if args is not None:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@ -73,13 +92,22 @@ def get_train_args(
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training.")
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
data_args.streaming = False
if data_args.dev_ratio > 1e-6 and data_args.streaming:
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
data_args.dev_ratio = 0
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:
@ -106,17 +134,7 @@ def get_train_args(
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
if args is not None:
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
@ -128,7 +146,4 @@ def get_infer_args(
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
return model_args, data_args, finetuning_args, generating_args

View File

@ -1,16 +1,19 @@
import os
import torch
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from trl import PreTrainedModelWrapper
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
@ -21,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self._remove_log()
@ -42,31 +45,35 @@ class PeftTrainer(Seq2SeqTrainer):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
state_dict = state_dict or get_state_dict(model)
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
backbone_model = getattr(model, "pretrained_model")
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
else:
backbone_model = model
if isinstance(model, PreTrainedModelWrapper):
model_params, v_head_params = {}, {}
for name in state_dict.keys():
if name.startswith("pretrained_model."):
model_params[name.replace("pretrained_model.", "")] = state_dict[name]
elif name.startswith("v_head."):
v_head_params[name.replace("v_head.", "")] = state_dict[name]
if isinstance(backbone_model, PeftModel): # LoRA tuning
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
safe_serialization=self.args.save_safetensors
)
backbone_model.config.use_cache = False
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
state_dict = model_params
model = model.pretrained_model
if isinstance(model, (PeftModel, PreTrainedModel)):
model.config.use_cache = True
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
model.config.use_cache = False
else:
logger.warning("No model to save.")
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
@ -76,16 +83,15 @@ class PeftTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if isinstance(backbone_model, PeftModel):
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if isinstance(model, PreTrainedModelWrapper):
model.v_head.load_state_dict(torch.load(
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
))
model = model.pretrained_model
if isinstance(model, PeftModel):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(backbone_model, self.state.best_model_checkpoint)
load_trainable_params(model, self.state.best_model_checkpoint)

View File

@ -2,21 +2,25 @@ import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
from transformers import TrainerState, TrainerControl
from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer
from trl.core import LengthSampler
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, get_logits_processor
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
@ -27,9 +31,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"""
def __init__(
self,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: List[LogCallback],
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: List["LogCallback"],
**kwargs
):
PPOTrainer.__init__(self, **kwargs)

View File

@ -1,11 +1,13 @@
import torch
from typing import Dict, List, Literal, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
valuehead_state_dict = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead,
model: "AutoModelForCausalLMWithValueHead",
layer_norm_names: List[str] = LAYERNORM_NAMES,
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
layer_norm_state_dict = {}

View File

@ -2,26 +2,30 @@
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
import math
from typing import TYPE_CHECKING
from trl import PPOConfig
from torch.optim import AdamW
from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_ppo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")

View File

@ -1,24 +1,27 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
import math
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")

View File

@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
features = [
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
for key in ("accept_ids", "reject_ids") for feature in features
]
return super().__call__(features)

View File

@ -1,13 +1,15 @@
import os
import json
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
logger = get_logger(__name__)
@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
def compute_loss(
self,
model: PreTrainedModel,
model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer):
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.

View File

@ -2,25 +2,27 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")

View File

@ -1,7 +1,6 @@
import numpy as np
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple, Union
from transformers.tokenization_utils import PreTrainedTokenizer
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
@dataclass
class ComputeMetrics:
@ -16,7 +18,7 @@ class ComputeMetrics:
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
"""
tokenizer: PreTrainedTokenizer
tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r"""

View File

@ -3,13 +3,15 @@ import json
import torch
import numpy as np
import torch.nn as nn
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
logger = get_logger(__name__)
@ -81,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
def save_predictions(
self,
predict_results: PredictionOutput
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.

View File

@ -1,25 +1,28 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_sft(
model_args: ModelArguments,
data_args: DataArguments,
training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
):
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")

View File

@ -54,7 +54,7 @@ class WebChatModel(ChatModel):
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
template=template,
source_prefix=source_prefix
)
super().__init__(*get_infer_args(args))

View File

@ -111,7 +111,7 @@ class Runner:
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
@ -201,7 +201,7 @@ class Runner:
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
prompt_template=template,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
dataset=",".join(dataset),