support ORPO

hiyouga 2024-03-31 18:29:50 +08:00
parent 27776c3474
commit 17bf8a2c3a
22 changed files with 395 additions and 47 deletions

View File

@ -68,16 +68,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
<details><summary>Full Changelog</summary>
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
<details><summary>Full Changelog</summary>
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
@ -165,6 +167,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use the `--quantization_bit 4` argument to enable QLoRA.

View File

@ -68,16 +68,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## Changelog
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available on arXiv!
[24/03/20] We supported **FSDP+QLoRA**, which fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
<details><summary>Full Changelog</summary>
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
[24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
<details><summary>Full Changelog</summary>
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for fast, concurrent inference. Use `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported; merge the weights first.)
[24/02/28] We supported **[DoRA](https://arxiv.org/abs/2402.09353)** fine-tuning. Use the `--use_dora` argument to enable DoRA fine-tuning.
@ -165,6 +167,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use the `--quantization_bit 4` argument to enable QLoRA training.

View File

@ -34,6 +34,8 @@ If you are using a custom dataset, please provide your dataset definition in the
Given the above, you can use the custom dataset by specifying `--dataset dataset_name`.
----
Currently we support datasets in the **alpaca** or **sharegpt** format; a dataset in the alpaca format should follow the format below:
```json
@ -84,6 +86,10 @@ For the preference datasets, the `response` column should be a string list whose
}
```
Remember to set `"ranking": true` for the preference datasets.
----
A dataset in the sharegpt format should follow the format below:
```json

View File

@ -34,6 +34,8 @@
After adding it, you can use the custom dataset by specifying the `--dataset dataset_name` argument.
----
This project currently supports datasets in two formats, **alpaca** and **sharegpt**; a dataset in the alpaca format should be organized as follows:
```json
@ -84,6 +86,10 @@
}
```
Adding a preference dataset additionally requires specifying `"ranking": true`.
----
A dataset in the sharegpt format should be organized as follows:
```json

View File

@ -1,8 +1,9 @@
Usage:
- `pretrain.sh`: do pre-training (optional)
- `sft.sh`: do supervised fine-tuning
- `reward.sh`: do reward modeling (must be run after sft.sh)
- `ppo.sh`: do PPO training (must be run after sft.sh and reward.sh)
- `dpo.sh`: do DPO training (must be run after sft.sh)
- `orpo.sh`: do ORPO training
- `predict.sh`: do prediction (must be run after sft.sh and dpo.sh)

View File

@ -0,0 +1,32 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage orpo \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset comparison_gpt4_en \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora/orpo \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--max_samples 1000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@ -1,6 +1,15 @@
from .collator import PairwiseDataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = [
"PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",
"templates",
"Role",
"split_dataset",
]

View File

@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
r"""
Masks out the input ids except for the responses.
"""
padded_labels = []
for feature, (prompt_len, answer_len) in zip(batch, positions):
if self.tokenizer.padding_side == "left":
start, end = feature.size(0) - answer_len, feature.size(0)
else:
start, end = prompt_len, prompt_len + answer_len
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
padded_tensor[start:end] = feature[start:end]
padded_labels.append(padded_tensor)
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
label_positions = []
for key in ("chosen_ids", "rejected_ids"):
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
concatenated_features.append(
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (prompt_len + answer_len),
}
)
label_positions.append((prompt_len, answer_len))
batch = super().__call__(concatenated_features)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
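A toy sketch of how this collator is used may help: each feature carries `prompt_ids`, `chosen_ids`, and `rejected_ids`, and the collator returns a padded batch of 2 * n rows with the chosen examples first and the prompt positions masked in `labels`. The tokenizer choice and the id values below are illustrative assumptions, not part of the repository.

```python
# Illustrative sketch; assumes `llmtuner` is importable (e.g. `src` on PYTHONPATH).
from transformers import AutoTokenizer
from llmtuner.data import PairwiseDataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

collator = PairwiseDataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
features = [
    {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7, 8], "chosen_ids": [9, 10, 11], "rejected_ids": [12, 13]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([4, 5]): rows 0-1 are chosen, rows 2-3 are rejected
print(batch["labels"][0])        # prompt positions are -100, response ids are kept
```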

View File

@ -117,7 +117,6 @@ def get_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"], stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template) template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
@ -138,6 +137,9 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)

View File

@ -23,16 +23,18 @@ def preprocess_pretrain_dataset(
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()

View File

@ -44,7 +44,7 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
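A quick numeric check of the clamped formula (the numbers are illustrative, not from the repository): when the target is short, the source no longer gives up tokens to headroom the target cannot use.

```python
# Sketch of the updated function above, exercised with made-up lengths.
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int):
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - min(max_target_len, target_len)  # clamp by the actual target length
    return max_source_len, max_target_len

print(infer_max_len(source_len=900, target_len=100, max_len=1024, reserved_label_len=1))
# (924, 102): without the min(...) clamp the source budget would be 1024 - 102 = 922
```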

View File

@ -134,6 +134,7 @@ class LogCallback(TrainerCallback):
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,

View File

@ -39,9 +39,12 @@ TRAINING_STAGES = {
"Reward Modeling": "rm", "Reward Modeling": "rm",
"PPO": "ppo", "PPO": "ppo",
"DPO": "dpo", "DPO": "dpo",
"ORPO": "orpo",
"Pre-Training": "pt", "Pre-Training": "pt",
} }
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"

View File

@ -110,6 +110,10 @@ class RLHFArguments:
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
@ -209,7 +213,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)

View File

@ -74,7 +74,7 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
@ -85,8 +85,8 @@ class CustomDPOTrainer(DPOTrainer):
return -all_logps
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
all_logits = model(
@ -107,9 +107,9 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
@ -142,21 +142,22 @@ class CustomDPOTrainer(DPOTrainer):
reference_chosen_logps,
reference_rejected_logps,
)
batch_loss = losses.mean()
if self.ftx_gamma > 1e-6:
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
return batch_loss, metrics
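In terms of the objective, the batch loss assembled above is the mean DPO loss plus an optional supervised term on the chosen responses, scaled by `ftx_gamma`. Written out (our reading of the code above, using its symbols):

$$\mathcal{L}_{\text{batch}} = \operatorname{mean}\big(\mathcal{L}_{\text{DPO}}\big) + \gamma_{\text{ftx}} \cdot \operatorname{mean}\big(\mathcal{L}_{\text{SFT}}(y_w)\big),$$

where $\mathcal{L}_{\text{SFT}}(y_w)$ is the average negative log-likelihood of each chosen response, as computed by `sft_loss`.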

View File

@ -2,13 +2,12 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push, create_ref_model
from .collator import DPODataCollatorWithPadding
from .trainer import CustomDPOTrainer
@ -29,7 +28,7 @@ def run_dpo(
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
@ -64,7 +63,7 @@ def run_dpo(
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:

View File

@ -0,0 +1,4 @@
from .workflow import run_orpo
__all__ = ["run_orpo"]

View File

@ -0,0 +1,150 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import FinetuningArguments
class CustomORPOTrainer(DPOTrainer):
def __init__(
self,
model: Union["PreTrainedModel", "torch.nn.Module"],
finetuning_args: "FinetuningArguments",
disable_dropout: bool = True,
**kwargs,
):
if disable_dropout:
disable_dropout_in_model(model)
self.finetuning_args = finetuning_args
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.is_encoder_decoder = model.config.is_encoder_decoder
self.precompute_ref_log_probs = False
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self._peft_has_been_casted_to_bf16 = False
self.beta = finetuning_args.orpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps
# Borrowed from:
# https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/orpo_trainer.py#L592
def odds_ratio_loss(
self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes ORPO's odds ratio (OR) loss.
Args:
chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
"""
# Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
log_odds = (chosen_logps - rejected_logps) - (
torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
losses = self.beta * ratio
chosen_rewards = self.beta * chosen_logps.detach()
rejected_rewards = self.beta * rejected_logps.detach()
return losses, chosen_rewards, rejected_rewards, ratio, log_odds
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
all_logits = model(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
).logits.to(torch.float32)
all_logps = self.get_batch_logps(
all_logits,
batch["labels"],
average_log_prob=False,
label_pad_token_id=self.label_pad_token_id,
)
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the ORPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
chosen_logps, rejected_logps
)
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
sft_loss = self.sft_loss(chosen_logits, chosen_labels)
batch_loss = (sft_loss - losses).mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
metrics["{}log_odds_ratio".format(prefix)] = log_odds_ratio.detach().cpu().mean()
metrics["{}log_odds_chosen".format(prefix)] = log_odds_chosen.detach().cpu().mean()
return batch_loss, metrics
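For reference, the batch loss assembled in `get_batch_loss_metrics` corresponds to the ORPO objective of the paper, with `orpo_beta` (stored as `self.beta`) playing the role of the paper's $\lambda$; since `losses` above equals $\beta \log \sigma(\cdot)$ and `batch_loss = (sft_loss - losses).mean()`, the quantity being minimized is, in our reading:

$$\mathcal{L}_{\text{ORPO}} = \mathbb{E}\Big[\mathcal{L}_{\text{SFT}} - \beta \log \sigma\Big(\log \frac{\operatorname{odds}_\theta(y_w \mid x)}{\operatorname{odds}_\theta(y_l \mid x)}\Big)\Big], \qquad \operatorname{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}.$$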

View File

@ -0,0 +1,68 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from .trainer import CustomORPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments
def run_orpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
)
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
# Initialize our Trainer
trainer = CustomORPOTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args, training_args),
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@ -2,13 +2,12 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_modelcard_and_push
from .collator import PairwiseDataCollatorWithPadding
from .metric import compute_accuracy
from .trainer import PairwiseTrainer

View File

@ -9,6 +9,7 @@ from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model_and_tokenizer
from .dpo import run_dpo
from .orpo import run_orpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
@ -36,6 +37,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")

View File

@ -11,6 +11,7 @@ from ..extras.constants import (
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
TRAINING_STAGES,
DownloadSource,
@ -127,7 +128,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(value=[], choices=datasets)