support ORPO

This commit is contained in:
parent 27776c3474
commit 17bf8a2c3a
@@ -68,16 +68,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 ## Changelog

+[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
+
 [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!

 [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.

+<details><summary>Full Changelog</summary>
+
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.

 [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.

-<details><summary>Full Changelog</summary>
-
 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)

 [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.

@@ -165,6 +167,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

 > [!NOTE]
 > Use `--quantization_bit 4` argument to enable QLoRA.

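The changelog entry above only points at `examples/lora_single_gpu`. As a minimal launch sketch, condensed from the `orpo.sh` script added later in this commit (the model name, dataset, and output path are the script's illustrative defaults, not requirements):

```bash
# Minimal ORPO run, condensed from examples/lora_single_gpu/orpo.sh (added in this commit).
# Run from the repository root; adjust model, dataset, and output_dir to your setup.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage orpo \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset comparison_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir saves/LLaMA2-7B/lora/orpo \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16
# Per the note above, `--quantization_bit 4` can additionally be passed to enable QLoRA.
```
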
@@ -68,16 +68,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 ## Changelog

+[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. For usage details, see `examples/lora_single_gpu`.
+
 [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!

 [24/03/20] We supported **FSDP+QLoRA**, which fine-tunes a 70B model on 2x24GB GPUs. For usage details, see `examples/fsdp_qlora`.

+<details><summary>Full Changelog</summary>
+
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. For usage details, see `examples/extras/loraplus`.

 [24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. For usage details, see `examples/extras/galore`.

-<details><summary>Full Changelog</summary>
-
 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for fast and concurrent inference. Use `--infer_backend vllm` to get **270%** inference speed. (LoRA is not yet supported; please merge the weights first.)

 [24/02/28] We supported **[DoRA](https://arxiv.org/abs/2402.09353)** fine-tuning. Use the `--use_dora` argument to perform DoRA fine-tuning.

@@ -165,6 +167,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

 > [!NOTE]
 > Use the `--quantization_bit 4` argument to enable QLoRA training.

@@ -34,6 +34,8 @@ If you are using a custom dataset, please provide your dataset definition in the

 Given above, you can use the custom dataset via specifying `--dataset dataset_name`.

+----
+
 Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:

 ```json

@@ -84,6 +86,10 @@ For the preference datasets, the `response` column should be a string list whose
 }
 ```

+Remember to set `"ranking": true` for the preference datasets.
+
+----
+
 The dataset in sharegpt format should follow the below format:

 ```json

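As a hedged illustration of the `"ranking": true` requirement above: a minimal entry one might add for a pairwise preference dataset in `data/dataset_info.json` (the dataset name and file below are placeholders, not something defined by this commit), shown via a shell heredoc purely for illustration:

```bash
# Placeholder registration for a preference dataset; merge it into data/dataset_info.json by hand.
cat <<'EOF'
"my_pairwise_data": {
  "file_name": "my_pairwise_data.json",
  "ranking": true
}
EOF
```

With such an entry, the dataset becomes selectable for the pairwise stages (`rm`, `dpo`, and the `orpo` stage added by this commit), matching the `STAGES_USE_PAIR_DATA` list introduced later in the diff.
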
@@ -34,6 +34,8 @@

 After adding it, you can use the custom dataset by specifying the `--dataset dataset_name` argument.

+----
+
 The project currently supports datasets in two formats: **alpaca** and **sharegpt**. The dataset in alpaca format should be organized as follows:

 ```json

@@ -84,6 +86,10 @@
 }
 ```

+Preference datasets additionally require `"ranking": true` to be specified.
+
+----
+
 The dataset in sharegpt format should be organized as follows:

 ```json

@@ -1,8 +1,9 @@
 Usage:

 - `pretrain.sh`: do pre-train (optional)
-- `sft.sh`: do supervised fine-tune
+- `sft.sh`: do supervised fine-tuning
 - `reward.sh`: do reward modeling (must after sft.sh)
 - `ppo.sh`: do PPO training (must after sft.sh and reward.sh)
 - `dpo.sh`: do DPO training (must after sft.sh)
+- `orpo.sh`: do ORPO training
 - `predict.sh`: do predict (must after sft.sh and dpo.sh)

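A usage sketch for the new `orpo.sh` entry, assuming these scripts are run from their own directory (which the relative `../../` paths in the script suggest):

```bash
cd examples/lora_single_gpu
bash orpo.sh  # launches the ORPO example reproduced in the next hunk
```
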
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage orpo \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset comparison_gpt4_en \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/orpo \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --preprocessing_num_workers 16 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --warmup_steps 20 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --max_samples 1000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16

@@ -1,6 +1,15 @@
+from .collator import PairwiseDataCollatorWithPadding
 from .loader import get_dataset
 from .template import Template, get_template_and_fix_tokenizer, templates
 from .utils import Role, split_dataset


-__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
+__all__ = [
+    "PairwiseDataCollatorWithPadding",
+    "get_dataset",
+    "Template",
+    "get_template_and_fix_tokenizer",
+    "templates",
+    "Role",
+    "split_dataset",
+]

@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence, Tuple
+
+import torch
+from transformers import DataCollatorForSeq2Seq
+
+
+@dataclass
+class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for pairwise data.
+    """
+
+    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
+        r"""
+        Masks out the input ids except for the responses.
+        """
+        padded_labels = []
+        for feature, (prompt_len, answer_len) in zip(batch, positions):
+            if self.tokenizer.padding_side == "left":
+                start, end = feature.size(0) - answer_len, feature.size(0)
+            else:
+                start, end = prompt_len, prompt_len + answer_len
+            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
+            padded_tensor[start:end] = feature[start:end]
+            padded_labels.append(padded_tensor)
+        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Pads batched data to the longest sequence in the batch.
+
+        We generate 2 * n examples where the first n examples represent chosen examples and
+        the last n examples represent rejected examples.
+        """
+        concatenated_features = []
+        label_positions = []
+        for key in ("chosen_ids", "rejected_ids"):
+            for feature in features:
+                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
+                concatenated_features.append(
+                    {
+                        "input_ids": feature["prompt_ids"] + feature[key],
+                        "attention_mask": [1] * (prompt_len + answer_len),
+                    }
+                )
+                label_positions.append((prompt_len, answer_len))
+
+        batch = super().__call__(concatenated_features)
+        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
+        return batch

@@ -117,7 +117,6 @@ def get_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"],
-    # split: Optional[str] = "train", # TODO: add split
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
     if data_args.train_on_prompt and template.efficient_eos:

@@ -138,6 +137,9 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
         for dataset_attr in get_dataset_list(data_args):
+            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+                raise ValueError("The dataset is not applicable in the current training stage.")
+
             all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
         dataset = merge_dataset(all_datasets, data_args, training_args)

@@ -23,16 +23,18 @@ def preprocess_pretrain_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
-    if not data_args.packing:
-        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
-
-    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
-    concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
-    total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
-    block_size = data_args.cutoff_len
-    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-    total_length = (total_length // block_size) * block_size
-    # split by chunks of cutoff_len
-    result = {
-        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
-        for k, t in concatenated_examples.items()
+
+    if not data_args.packing:
+        if data_args.template == "gemma":
+            text_examples = [tokenizer.bos_token + example for example in text_examples]
+
+        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+    else:
+        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
+        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+        block_size = data_args.cutoff_len
+        total_length = (total_length // block_size) * block_size
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()

@@ -44,7 +44,7 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
 def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
     max_target_len = int(max_len * (target_len / (source_len + target_len)))
     max_target_len = max(max_target_len, reserved_label_len)
-    max_source_len = max_len - max_target_len
+    max_source_len = max_len - min(max_target_len, target_len)
     return max_source_len, max_target_len


@@ -134,6 +134,7 @@ class LogCallback(TrainerCallback):
             eval_loss=state.log_history[-1].get("eval_loss", None),
             predict_loss=state.log_history[-1].get("predict_loss", None),
             reward=state.log_history[-1].get("reward", None),
+            accuracy=state.log_history[-1].get("rewards/accuracies", None),
             learning_rate=state.log_history[-1].get("learning_rate", None),
             epoch=state.log_history[-1].get("epoch", None),
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,

@@ -39,9 +39,12 @@ TRAINING_STAGES = {
     "Reward Modeling": "rm",
     "PPO": "ppo",
     "DPO": "dpo",
+    "ORPO": "orpo",
     "Pre-Training": "pt",
 }

+STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"

 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"

@@ -110,6 +110,10 @@ class RLHFArguments:
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
+    orpo_beta: float = field(
+        default=0.1,
+        metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
+    )
     ppo_buffer_size: int = field(
         default=1,
         metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},

@@ -209,7 +213,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
     )
-    stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )

@@ -74,7 +74,7 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

-    def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
+    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
         r"""
         Computes supervised cross-entropy loss of given labels under the given logits.

@@ -85,8 +85,8 @@ class CustomDPOTrainer(DPOTrainer):
         return -all_logps

     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error

         all_logits = model(

@@ -107,9 +107,9 @@ class CustomDPOTrainer(DPOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, torch.Tensor],
+        batch: Dict[str, "torch.Tensor"],
         train_eval: Literal["train", "eval"] = "train",
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
         r"""
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """

@@ -142,21 +142,22 @@ class CustomDPOTrainer(DPOTrainer):
             reference_chosen_logps,
             reference_rejected_logps,
         )
+        batch_loss = losses.mean()
         if self.ftx_gamma > 1e-6:
             batch_size = batch["input_ids"].size(0) // 2
             chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+            batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()

         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
+        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
+        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
+        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
+        metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
+        metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
+        metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()

-        return losses.mean(), metrics
+        return batch_loss, metrics

@@ -2,13 +2,12 @@

 from typing import TYPE_CHECKING, List, Optional

-from ...data import get_dataset, split_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
 from ..utils import create_modelcard_and_push, create_ref_model
-from .collator import DPODataCollatorWithPadding
 from .trainer import CustomDPOTrainer


@@ -29,7 +28,7 @@ def run_dpo(
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

-    data_collator = DPODataCollatorWithPadding(
+    data_collator = PairwiseDataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,

@@ -64,7 +63,7 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])

     # Evaluation
     if training_args.do_eval:

@@ -0,0 +1,4 @@
+from .workflow import run_orpo
+
+
+__all__ = ["run_orpo"]

@@ -0,0 +1,150 @@
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import Trainer
+from trl import DPOTrainer
+from trl.trainer.utils import disable_dropout_in_model
+
+from ...extras.constants import IGNORE_INDEX
+from ..utils import create_custom_optimzer, create_custom_scheduler
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ...hparams import FinetuningArguments
+
+
+class CustomORPOTrainer(DPOTrainer):
+    def __init__(
+        self,
+        model: Union["PreTrainedModel", "torch.nn.Module"],
+        finetuning_args: "FinetuningArguments",
+        disable_dropout: bool = True,
+        **kwargs,
+    ):
+        if disable_dropout:
+            disable_dropout_in_model(model)
+
+        self.finetuning_args = finetuning_args
+        self.reference_free = False
+        self.use_dpo_data_collator = True  # hack to avoid warning
+        self.generate_during_eval = False  # disable at evaluation
+        self.label_pad_token_id = IGNORE_INDEX
+        self.padding_value = 0
+        self.is_encoder_decoder = model.config.is_encoder_decoder
+        self.precompute_ref_log_probs = False
+        self._precomputed_train_ref_log_probs = False
+        self._precomputed_eval_ref_log_probs = False
+        self._peft_has_been_casted_to_bf16 = False
+
+        self.beta = finetuning_args.orpo_beta
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+        Trainer.__init__(self, model=model, **kwargs)
+
+    def create_optimizer(self) -> "torch.optim.Optimizer":
+        if self.optimizer is None:
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
+
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
+
+    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
+        r"""
+        Computes supervised cross-entropy loss of given labels under the given logits.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
+        """
+        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
+        return -all_logps
+
+    # Borrowed from:
+    # https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/orpo_trainer.py#L592
+    def odds_ratio_loss(
+        self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor"
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""
+        Computes ORPO's odds ratio (OR) loss.
+
+        Args:
+            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+
+        Returns:
+            A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
+        """
+
+        # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        losses = self.beta * ratio
+
+        chosen_rewards = self.beta * chosen_logps.detach()
+        rejected_rewards = self.beta * rejected_logps.detach()
+
+        return losses, chosen_rewards, rejected_rewards, ratio, log_odds
+
+    def concatenated_forward(
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        all_logits = model(
+            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
+        ).logits.to(torch.float32)
+
+        all_logps = self.get_batch_logps(
+            all_logits,
+            batch["labels"],
+            average_log_prob=False,
+            label_pad_token_id=self.label_pad_token_id,
+        )
+        batch_size = batch["input_ids"].size(0) // 2
+        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+
+    def get_batch_loss_metrics(
+        self,
+        model: "PreTrainedModel",
+        batch: Dict[str, "torch.Tensor"],
+        train_eval: Literal["train", "eval"] = "train",
+    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
+        r"""
+        Computes the ORPO loss and other metrics for the given batch of inputs for train or test.
+        """
+        metrics = {}
+        chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
+
+        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
+            chosen_logps, rejected_logps
+        )
+        batch_size = batch["input_ids"].size(0) // 2
+        chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
+        sft_loss = self.sft_loss(chosen_logits, chosen_labels)
+        batch_loss = (sft_loss - losses).mean()
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
+        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
+        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
+        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
+        metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
+        metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
+        metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
+        metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
+        metrics["{}log_odds_ratio".format(prefix)] = log_odds_ratio.detach().cpu().mean()
+        metrics["{}log_odds_chosen".format(prefix)] = log_odds_chosen.detach().cpu().mean()
+
+        return batch_loss, metrics

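For reference, a sketch of the objective that `odds_ratio_loss` and `get_batch_loss_metrics` above implement, following Eqs. (4) and (7) of the ORPO paper cited in the code (https://arxiv.org/abs/2403.07691); the notation here is ours, and `lambda` corresponds to `orpo_beta` (`self.beta`):

```latex
% Odds of a response y given prompt x under the policy \pi_\theta:
\mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}

% Odds-ratio term; `log_odds` in the code is the expanded form of the inner log-ratio:
\mathcal{L}_{OR} = -\log \sigma\!\left( \log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)} \right)

% Total objective returned as `batch_loss`: token-averaged SFT loss on the chosen response
% plus the weighted odds-ratio term, i.e. (sft_loss - self.beta * logsigmoid(log_odds)).mean():
\mathcal{L}_{ORPO} = \mathbb{E}_{(x,\, y_w,\, y_l)}\left[ \mathcal{L}_{SFT} + \lambda \cdot \mathcal{L}_{OR} \right]
```
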
@@ -0,0 +1,68 @@
+# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model, load_tokenizer
+from ..utils import create_modelcard_and_push
+from .trainer import CustomORPOTrainer
+
+
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+    from ...hparams import DataArguments, FinetuningArguments
+
+
+def run_orpo(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    tokenizer = load_tokenizer(model_args)
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+
+    data_collator = PairwiseDataCollatorWithPadding(
+        tokenizer=tokenizer,
+        pad_to_multiple_of=8,
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+    )
+
+    # Update arguments
+    training_args.remove_unused_columns = False  # important for pairwise dataset
+
+    # Initialize our Trainer
+    trainer = CustomORPOTrainer(
+        model=model,
+        args=training_args,
+        finetuning_args=finetuning_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **split_dataset(dataset, data_args, training_args),
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    # Create model card
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

@@ -2,13 +2,12 @@

 from typing import TYPE_CHECKING, List, Optional

-from ...data import get_dataset, split_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..utils import create_modelcard_and_push
-from .collator import PairwiseDataCollatorWithPadding
 from .metric import compute_accuracy
 from .trainer import PairwiseTrainer

@@ -9,6 +9,7 @@ from ..extras.logging import get_logger
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model_and_tokenizer
 from .dpo import run_dpo
+from .orpo import run_orpo
 from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm

@@ -36,6 +37,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
         run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
     elif finetuning_args.stage == "dpo":
         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif finetuning_args.stage == "orpo":
+        run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
         raise ValueError("Unknown task.")

@@ -11,6 +11,7 @@ from ..extras.constants import (
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
     PEFT_METHODS,
+    STAGES_USE_PAIR_DATA,
    SUPPORTED_MODELS,
     TRAINING_STAGES,
     DownloadSource,

@@ -127,7 +128,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:

 def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
-    ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
+    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
     datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
     return gr.Dropdown(value=[], choices=datasets)