Merge branch 'main' into add_dataset_sample_num

seanzhang-zhichen 2024-05-24 15:57:47 +08:00 committed by GitHub
commit 27cb51f7f8
29 changed files with 756 additions and 513 deletions


@@ -69,12 +69,12 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `gemma` template for chat completion.

[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.

[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.

[24/05/13] We supported fine-tuning the **Yi-1.5** series models.

<details><summary>Full Changelog</summary>

[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
@@ -160,6 +160,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
@@ -284,11 +285,11 @@ huggingface-cli login
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 |
- | transformers | 4.37.2 | 4.40.1 |
+ | transformers | 4.37.2 | 4.41.0 |
| datasets | 2.14.3 | 2.19.1 |
- | accelerate | 0.27.2 | 0.30.0 |
+ | accelerate | 0.27.2 | 0.30.1 |
- | peft | 0.9.0 | 0.10.0 |
+ | peft | 0.9.0 | 0.11.1 |
- | trl | 0.8.1 | 0.8.6 |
+ | trl | 0.8.2 | 0.8.6 |

| Optional | Minimum | Recommend |
| ------------ | ------- | --------- |
@@ -344,6 +345,8 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec

<details><summary>For Ascend NPU users</summary>

Join [NPU user group](assets/wechat_npu.jpg).

To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.

| Requirement | Minimum | Recommend |

@@ -356,7 +359,7 @@ To utilize Ascend NPU devices for (distributed) training and inference, you need

Docker image:
- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- - 64GB: Coming soon
+ - 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
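The device-selection note above can be illustrated with a small, non-authoritative sketch: on Ascend hardware the runtime reads ASCEND_RT_VISIBLE_DEVICES exactly where a CUDA setup would read CUDA_VISIBLE_DEVICES. The entry-point name in the snippet is a placeholder, not a command defined by this repository.

import os
import subprocess

# Expose only the first NPU to the child process (assumption: a single-card run);
# this mirrors the usual CUDA_VISIBLE_DEVICES pattern, with the Ascend variable instead.
env = dict(os.environ, ASCEND_RT_VISIBLE_DEVICES="0")

# "train.py --your-args" is a hypothetical stand-in for your actual training entry point.
subprocess.run(["python", "train.py", "--your-args"], env=env, check=True)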


@@ -69,12 +69,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

## 更新日志

[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。

[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。

[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。

[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。

<details><summary>展开日志</summary>

[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
@@ -160,6 +160,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
@@ -284,11 +285,11 @@ huggingface-cli login
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 |
- | transformers | 4.37.2 | 4.40.1 |
+ | transformers | 4.37.2 | 4.41.0 |
| datasets | 2.14.3 | 2.19.1 |
- | accelerate | 0.27.2 | 0.30.0 |
+ | accelerate | 0.27.2 | 0.30.1 |
- | peft | 0.9.0 | 0.10.0 |
+ | peft | 0.9.0 | 0.11.1 |
- | trl | 0.8.1 | 0.8.6 |
+ | trl | 0.8.2 | 0.8.6 |

| 可选项 | 至少 | 推荐 |
| ------------ | ------- | --------- |
@@ -344,6 +345,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

<details><summary>昇腾 NPU 用户指南</summary>

加入 [NPU 用户群](assets/wechat_npu.jpg)。

如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。

| 依赖项 | 至少 | 推荐 |

@@ -356,7 +359,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

Docker 镜像:
- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- - 64GB:敬请期待
+ - 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。

[Binary image file updated: 146 KiB → 145 KiB; preview not shown.]

[assets/wechat_npu.jpg added as a new binary file: 146 KiB; preview not shown.]


@@ -7,7 +7,7 @@
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
- "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
+ "file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)",
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
"ranking": "是否为偏好数据集(可选,默认:False)",
"subset": "数据集子集的名称(可选,默认:None)",


@@ -34,7 +34,8 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
        features = datasets.Features(
            {
                "instruction": datasets.Value("string"),
-                 "output": datasets.Sequence(datasets.Value("string")),
+                 "chosen": datasets.Value("string"),
+                 "rejected": datasets.Value("string"),
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
            }
        )


@@ -8,6 +8,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response

@@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine):
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
-         if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-             messages[0]["content"] = "<image>" + messages[0]["content"]
+         if (
+             processor is not None
+             and image is not None
+             and not hasattr(processor, "image_seq_length")
+             and IMAGE_TOKEN not in messages[0]["content"]
+         ):  # llava-like models
+             messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        pixel_values = None
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )
        if processor is not None and image is not None:  # add image features
            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
            batch_feature = image_processor(image, return_tensors="pt")
            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
            if hasattr(processor, "image_seq_length"):  # paligemma models
                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)

@@ -122,10 +137,8 @@
            logits_processor=get_logits_processor(),
        )

-         if processor is not None and image is not None:
-             image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-             pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-             gen_kwargs["pixel_values"] = pixel_values.to(model.device)
+         if pixel_values is not None:
+             gen_kwargs["pixel_values"] = pixel_values

        return gen_kwargs, prompt_length


@@ -2,6 +2,7 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available

@@ -17,7 +18,6 @@ if is_vllm_available():

if TYPE_CHECKING:
    import torch
    from numpy.typing import NDArray
    from transformers.image_processing_utils import BaseImageProcessor

@@ -67,7 +67,7 @@ class VllmEngine(BaseEngine):
            patch_size = config.vision_config.patch_size
            self.image_feature_size = (image_size // patch_size) ** 2
            engine_args["image_input_type"] = "pixel_values"
-             engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+             engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
            engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
            engine_args["image_feature_size"] = self.image_feature_size
            if getattr(config, "is_yi_vl_derived_model", None):

@@ -92,14 +92,28 @@
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)

-         if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-             messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
+         if (
+             self.processor is not None
+             and image is not None
+             and not hasattr(self.processor, "image_seq_length")
+             and IMAGE_TOKEN not in messages[0]["content"]
+         ):  # llava-like models
+             messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(
            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
        )

        if self.processor is not None and image is not None:  # add image features
            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
        else:
            multi_modal_data = None

        prompt_length = len(prompt_ids)
        use_beam_search: bool = self.generating_args["num_beams"] > 1

@@ -144,13 +158,6 @@
            skip_special_tokens=True,
        )

-         if self.processor is not None and image is not None:
-             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-             pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-             multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-         else:
-             multi_modal_data = None

        result_generator = self.model.generate(
            prompt=None,
            sampling_params=sampling_params,


@@ -1,5 +1,5 @@
from dataclasses import dataclass
- from typing import Any, Dict, List, Sequence, Tuple
+ from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorForSeq2Seq
@@ -11,21 +11,6 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    Data collator for pairwise data.
    """

-     def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
-         r"""
-         Masks out the input ids except for the responses.
-         """
-         padded_labels = []
-         for feature, (prompt_len, answer_len) in zip(batch, positions):
-             if self.tokenizer.padding_side == "left":
-                 start, end = feature.size(0) - answer_len, feature.size(0)
-             else:
-                 start, end = prompt_len, prompt_len + answer_len
-             padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
-             padded_tensor[start:end] = feature[start:end]
-             padded_labels.append(padded_tensor)
-         return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

@@ -34,21 +19,22 @@
        the last n examples represent rejected examples.
        """
        concatenated_features = []
-         label_positions = []
-         for key in ("chosen_ids", "rejected_ids"):
-             for feature in features:
-                 prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
-                 concatenated_features.append(
-                     {
-                         "input_ids": feature["prompt_ids"] + feature[key],
-                         "attention_mask": [1] * (prompt_len + answer_len),
-                     }
-                 )
-                 label_positions.append((prompt_len, answer_len))
-
-         batch = super().__call__(concatenated_features)
-         batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
-         return batch
+         for key in ("chosen", "rejected"):
+             for feature in features:
+                 target_feature = {
+                     "input_ids": feature["{}_input_ids".format(key)],
+                     "attention_mask": feature["{}_attention_mask".format(key)],
+                     "labels": feature["{}_labels".format(key)],
+                 }
+                 if "pixel_values" in feature:
+                     target_feature["pixel_values"] = feature["pixel_values"]
+
+                 if "{}_token_type_ids".format(key) in feature:
+                     target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
+
+                 concatenated_features.append(target_feature)
+
+         return super().__call__(concatenated_features)
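The new pairwise __call__ relies on a specific batch layout ("the first n examples represent chosen examples and the last n examples represent rejected examples"). The following is a minimal standalone sketch of that layout, not repository code; the feature keys mirror the diff, while the token ids are made up.

from typing import Any, Dict, List, Sequence

def concatenate_pairwise(features: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Emit all chosen sequences first, then all rejected sequences, so that
    # row i and row i + len(features) of the resulting batch form one preference pair.
    concatenated = []
    for key in ("chosen", "rejected"):
        for feature in features:
            concatenated.append({
                "input_ids": feature["{}_input_ids".format(key)],
                "attention_mask": feature["{}_attention_mask".format(key)],
                "labels": feature["{}_labels".format(key)],
            })
    return concatenated

example = {
    "chosen_input_ids": [1, 2, 3], "chosen_attention_mask": [1, 1, 1], "chosen_labels": [-100, 2, 3],
    "rejected_input_ids": [1, 9], "rejected_attention_mask": [1, 1], "rejected_labels": [-100, 9],
}
rows = concatenate_pairwise([example])
print([row["input_ids"] for row in rows])  # [[1, 2, 3], [1, 9]] -> chosen first, then rejected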
@dataclass

@@ -62,20 +48,25 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
        kl_features = []
        kto_tags = []
        for feature in features:
-             target_features.append(
-                 {
-                     "input_ids": feature["input_ids"],
-                     "attention_mask": feature["attention_mask"],
-                     "labels": feature["labels"],
-                 }
-             )
-             kl_features.append(
-                 {
-                     "input_ids": feature["kl_input_ids"],
-                     "attention_mask": feature["kl_attention_mask"],
-                     "labels": feature["kl_labels"],
-                 }
-             )
+             target_feature = {
+                 "input_ids": feature["input_ids"],
+                 "attention_mask": feature["attention_mask"],
+                 "labels": feature["labels"],
+             }
+             kl_feature = {
+                 "input_ids": feature["kl_input_ids"],
+                 "attention_mask": feature["kl_attention_mask"],
+                 "labels": feature["kl_labels"],
+             }
+             if "pixel_values" in feature:
+                 target_feature["pixel_values"] = feature["pixel_values"]
+
+             if "token_type_ids" in feature:
+                 target_feature["token_type_ids"] = feature["token_type_ids"]
+                 kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
+
+             target_features.append(target_feature)
+             kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)

@@ -83,5 +74,8 @@
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch


@@ -2,6 +2,7 @@ import inspect
import os

import numpy as np
from numpy.random import RandomState
import sys
from typing import TYPE_CHECKING, Literal, Optional, Union

from datasets import load_dataset, load_from_disk

@@ -180,12 +181,15 @@ def get_dataset(
            logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
            logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
-         exit(0)
+         sys.exit(0)

    if training_args.should_log:
        try:
            print_function(next(iter(dataset)))
        except StopIteration:
-             raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
+             if stage == "pt":
+                 raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
+             else:
+                 raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

    return dataset


@@ -1,380 +1,25 @@
from functools import partial
- from itertools import chain
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
-
- from ..extras.constants import IGNORE_INDEX
- from ..extras.logging import get_logger
- from ..extras.packages import is_pillow_available
- from .utils import Role
-
-
- if is_pillow_available():
-     from PIL import Image
+ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
+
+ from .processors.feedback import preprocess_feedback_dataset
+ from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
+ from .processors.pretrain import preprocess_pretrain_dataset
+ from .processors.supervised import (
+     preprocess_packed_supervised_dataset,
+     preprocess_supervised_dataset,
+     print_supervised_dataset_example,
+ )
+ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
-     from numpy.typing import NDArray
-     from PIL.Image import Image as ImageObject
    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-     from transformers.image_processing_utils import BaseImageProcessor
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ..hparams import DataArguments
    from .template import Template


- logger = get_logger(__name__)
def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0]
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
messages = examples["prompt"][i] + examples["response"][i]
for source_ids, target_ids in template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i]
):
if data_args.train_on_prompt:
source_mask = source_ids
elif len(input_ids) != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer,
chosen_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
messages = examples["prompt"][i] + [examples["response"][i][0]]
else: # undesired example
kto_tag = False
messages = examples["prompt"][i] + [examples["response"][i][1]]
if kl_response[i][0]["content"]:
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
else:
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, kl_response_ids = template.encode_oneturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(
"labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
)
)
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",

@@ -419,7 +64,7 @@ def get_preprocess_and_print_func(
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(
-             preprocess_kto_dataset,
+             preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,


@@ -0,0 +1,110 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def preprocess_feedback_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
messages = examples["prompt"][i] + [examples["response"][i][0]]
else: # undesired example
kto_tag = False
messages = examples["prompt"][i] + [examples["response"][i][1]]
if kl_response[i][0]["content"]:
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
else:
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, kl_response_ids = template.encode_oneturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
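The comment at the top of preprocess_feedback_dataset ("flipping the matched pairs") can be made concrete with a toy sketch, not taken from the repository: reversing the response list pairs prompt i with the response of example len-1-i, which KTO uses as a mismatched sample when estimating the KL term. The strings below are made up.

prompts = ["p0", "p1", "p2", "p3"]
responses = ["r0", "r1", "r2", "r3"]

# Mirror of kl_response = examples["response"][::-1] in the function above.
kl_responses = responses[::-1]
print(list(zip(prompts, kl_responses)))  # [('p0', 'r3'), ('p1', 'r2'), ('p2', 'r1'), ('p3', 'r0')]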


@@ -0,0 +1,27 @@
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
# get paligemma token type ids for computing loss
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
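A rough usage sketch for the helper above (toy numbers, not from the repository): PaliGemma prompts are built elsewhere in this commit as [image_token_id] * image_seq_length + text_ids, and the matching token type ids mark the image placeholder positions with 0 and the text positions with 1.

# Assume a processor whose image_seq_length is 4, and some made-up text token ids.
image_seq_length = 4
text_ids = [101, 102, 103]

input_len = image_seq_length + len(text_ids)
token_type_ids = [0] * image_seq_length + [1] * (input_len - image_seq_length)
print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1] -> zeros cover the image slots, ones cover the text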


@@ -0,0 +1,109 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer,
chosen_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"].append(
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
)
model_inputs["rejected_token_type_ids"].append(
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
)
return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))


@@ -0,0 +1,36 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
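As a toy walk-through of the packing branch above (made-up token ids, arbitrary block size): all tokenized examples are concatenated, the total length is truncated to a multiple of the block size (data_args.cutoff_len), and the stream is split into fixed-size blocks, dropping the remainder.

from itertools import chain

tokenized = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]  # pretend tokenizer output
block_size = 4

concatenated = list(chain(*tokenized))              # [1, 2, ..., 10]
total_length = (len(concatenated) // block_size) * block_size
blocks = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]] -> tokens 9 and 10 are dropped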


@@ -0,0 +1,137 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
messages = examples["prompt"][i] + examples["response"][i]
for source_ids, target_ids in template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i]
):
if data_args.train_on_prompt:
source_mask = source_ids
elif len(input_ids) != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))


@@ -0,0 +1,76 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IMAGE_TOKEN
from ...extras.logging import get_logger
from ..utils import Role
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
else:
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))


@@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
                 slot_items.append(placeholder)
             if slot_pieces[1]:
                 slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
-        elif isinstance(slot, set):
-            if "bos_token" in slot:
+        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+            if "bos_token" in slot and tokenizer.bos_token_id is not None:
                 slot_items.append("'" + tokenizer.bos_token + "'")
-            elif "eos_token" in slot:  # do not use {{ eos_token }} since it may be replaced
+            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                 slot_items.append("'" + tokenizer.eos_token + "'")
         elif isinstance(slot, dict):
             raise ValueError("Dict is not supported.")
@@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
     jinja_template += "{% set content = " + system_message + " + message['content'] %}"
     jinja_template += "{% endif %}"
     jinja_template += "{% if message['role'] == 'user' %}"
     user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
     jinja_template += "{{ " + user_message + " }}"
     jinja_template += "{% elif message['role'] == 'assistant' %}"
     assistant_message = _convert_slots_to_jinja(
         template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@@ -614,6 +616,9 @@ _register_template(
     name="empty",
     format_user=StringFormatter(slots=["{{content}}"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    efficient_eos=True,
+    force_system=True,
 )
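
The Jinja changes in this file only affect how the chat template string is built and guarded. As a minimal sketch, this is how such a template is typically exercised once it has been assigned to `tokenizer.chat_template`; the model id is an arbitrary ungated example and is not taken from this commit.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")  # arbitrary chat model
# tokenizer.chat_template = _get_jinja_template(template, tokenizer)  # assumed wiring, not shown in this hunk

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Explain LoRA in one sentence."},
]
# Render the prompt string exactly as the Jinja template defines it.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```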

View File

@@ -22,6 +22,8 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

+IMAGE_TOKEN = "<image>"
+
 LAYERNORM_NAMES = {"norm", "ln"}

 METHODS = ["full", "freeze", "lora"]
@@ -714,6 +716,28 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "PaliGemma-3B-pt-224": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
+        },
+        "PaliGemma-3B-pt-448": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
+        },
+        "PaliGemma-3B-pt-896": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
+        },
+        "PaliGemma-3B-mix-224": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
+        },
+        "PaliGemma-3B-mix-448": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
+        },
+    },
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Phi-1.5-1.3B": {

View File

@@ -65,7 +65,7 @@ def check_dependencies() -> None:
     require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
     require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
     require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
-    require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
+    require_version("trl>=0.8.2", "To fix: pip install trl>=0.8.2")


 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

View File

@@ -145,7 +145,7 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
-    export_device: str = field(
+    export_device: Literal["cpu", "cuda"] = field(
         default="cpu",
         metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
     )

View File

@@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)

-    if model_args.export_dir is not None:
-        model_args.device_map = {"": torch.device(model_args.export_device)}
+    if model_args.export_dir is not None and model_args.export_device == "cpu":
+        model_args.device_map = {"": torch.device("cpu")}
     else:
         model_args.device_map = "auto"

View File

@@ -4,7 +4,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union

 import torch
-from transformers import BatchEncoding, Trainer
+from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
@@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
         Otherwise the average log probabilities.
         """
-        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
-        all_logits: "torch.Tensor" = model(
-            input_ids=batch_copied["input_ids"],
-            attention_mask=batch_copied["attention_mask"],
-            return_dict=True,
-            use_cache=False,
-        ).logits.to(torch.float32)
+        batch_copied = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)

         all_logps = self.get_batch_logps(
             logits=all_logits,

View File

@@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         with torch.no_grad():
-            kl_logits = model(
-                input_ids=batch["kl_input_ids"],
-                attention_mask=batch["kl_attention_mask"],
-                return_dict=True,
-                use_cache=False,
-            ).logits.to(torch.float32)
-
-        target_logits = model(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-            return_dict=True,
-            use_cache=False,
-        ).logits.to(torch.float32)
+            kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
+            if "pixel_values" in batch:
+                kl_model_inputs["pixel_values"] = batch["pixel_values"]
+
+            if "kl_token_type_ids" in batch:
+                kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
+
+            kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
+
+        model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
+        if "pixel_values" in batch:
+            model_inputs["pixel_values"] = batch["pixel_values"]
+
+        if "token_type_ids" in batch:
+            model_inputs["token_type_ids"] = batch["token_type_ids"]
+
+        target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)

         target_logps = self.get_batch_logps(
             logits=target_logits,
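
The refactor above switches both forwards to keyword dictionaries so that optional multimodal keys are only passed when present in the batch. A standalone restatement of that pattern with a toy batch is shown below; the `build_model_inputs` helper is hypothetical and not part of the repository.

```python
from typing import Any, Dict


def build_model_inputs(batch: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    # Collect the mandatory keys (optionally prefixed, e.g. "kl_" for the KL forward)
    # and attach multimodal keys only if the batch actually carries them.
    inputs = {"input_ids": batch[prefix + "input_ids"], "attention_mask": batch[prefix + "attention_mask"]}
    if "pixel_values" in batch:  # shared between the target and KL forwards
        inputs["pixel_values"] = batch["pixel_values"]
    if prefix + "token_type_ids" in batch:  # paligemma-style batches
        inputs["token_type_ids"] = batch[prefix + "token_type_ids"]
    return inputs


batch = {
    "input_ids": [1, 2],
    "attention_mask": [1, 1],
    "kl_input_ids": [3, 4],
    "kl_attention_mask": [1, 1],
    "pixel_values": [[0.0]],
}
print(build_model_inputs(batch))         # inputs for the target forward
print(build_model_inputs(batch, "kl_"))  # inputs for the KL forward
```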

View File

@@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
         r"""
         Computes the average log probabilities of the labels under the given logits.
         """
-        all_logits: "torch.Tensor" = model(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
-        ).logits.to(torch.float32)
+        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)

         all_logps = self.get_batch_logps(
             logits=all_logits,

View File

@@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Accordion(open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
-            dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
-            orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+            pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+            pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
+            pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
             reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
+            with gr.Column():
+                ppo_score_norm = gr.Checkbox()
+                ppo_whiten_rewards = gr.Checkbox()

-    input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
+    input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
     elem_dict.update(
-        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
+        dict(
+            rlhf_tab=rlhf_tab,
+            pref_beta=pref_beta,
+            pref_ftx=pref_ftx,
+            pref_loss=pref_loss,
+            reward_model=reward_model,
+            ppo_score_norm=ppo_score_norm,
+            ppo_whiten_rewards=ppo_whiten_rewards,
+        )
     )

     with gr.Accordion(open=False) as galore_tab:

View File

@@ -774,52 +774,52 @@ LOCALES = {
             "label": "RLHF 参数设置",
         },
     },
-    "dpo_beta": {
+    "pref_beta": {
         "en": {
-            "label": "DPO beta",
-            "info": "Value of the beta parameter in the DPO loss.",
+            "label": "Beta value",
+            "info": "Value of the beta parameter in the loss.",
         },
         "ru": {
-            "label": "DPO бета",
-            "info": "Значение параметра бета в функции потерь DPO.",
+            "label": "Бета значение",
+            "info": "Значение параметра бета в функции потерь.",
         },
         "zh": {
-            "label": "DPO beta 参数",
-            "info": "DPO 损失函数中 beta 超参数大小。",
+            "label": "Beta 参数",
+            "info": "损失函数中 beta 超参数大小。",
         },
     },
-    "dpo_ftx": {
+    "pref_ftx": {
         "en": {
-            "label": "DPO-ftx weight",
-            "info": "The weight of SFT loss in the DPO-ftx.",
+            "label": "Ftx gamma",
+            "info": "The weight of SFT loss in the final loss.",
         },
         "ru": {
-            "label": "Вес DPO-ftx",
-            "info": "Вес функции потерь SFT в DPO-ftx.",
+            "label": "Ftx гамма",
+            "info": "Вес потери SFT в итоговой потере.",
        },
         "zh": {
-            "label": "DPO-ftx 权重",
-            "info": "DPO-ftx 中 SFT 损失的权重大小。",
+            "label": "Ftx gamma",
+            "info": "损失函数中 SFT 损失的权重大小。",
         },
     },
-    "orpo_beta": {
+    "pref_loss": {
         "en": {
-            "label": "ORPO beta",
-            "info": "Value of the beta parameter in the ORPO loss.",
+            "label": "Loss type",
+            "info": "The type of the loss function.",
         },
         "ru": {
-            "label": "ORPO бета",
-            "info": "Значение параметра бета в функции потерь ORPO.",
+            "label": "Тип потерь",
+            "info": "Тип функции потерь.",
         },
         "zh": {
-            "label": "ORPO beta 参数",
-            "info": "ORPO 损失函数中 beta 超参数大小",
+            "label": "损失类型",
+            "info": "损失函数的类型",
         },
     },
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Adapter of the reward model for PPO training.",
+            "info": "Adapter of the reward model in PPO training.",
         },
         "ru": {
             "label": "Модель вознаграждения",
@@ -830,6 +830,34 @@ LOCALES = {
             "info": "PPO 训练中奖励模型的适配器路径。",
         },
     },
+    "ppo_score_norm": {
+        "en": {
+            "label": "Score norm",
+            "info": "Normalizing scores in PPO training.",
+        },
+        "ru": {
+            "label": "Норма оценок",
+            "info": "Нормализация оценок в тренировке PPO.",
+        },
+        "zh": {
+            "label": "归一化分数",
+            "info": "PPO 训练中归一化奖励分数。",
+        },
+    },
+    "ppo_whiten_rewards": {
+        "en": {
+            "label": "Whiten rewards",
+            "info": "Whiten the rewards in PPO training.",
+        },
+        "ru": {
+            "label": "Белые вознаграждения",
+            "info": "Осветлите вознаграждения в обучении PPO.",
+        },
+        "zh": {
+            "label": "白化奖励",
+            "info": "PPO 训练中将奖励分数做白化处理。",
+        },
+    },
     "galore_tab": {
         "en": {
             "label": "GaLore configurations",

View File

@@ -145,11 +145,14 @@ class Runner:
             plot_loss=True,
         )

+        # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
             args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
             args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
-        elif args["finetuning_type"] == "lora":
+
+        # lora config
+        if args["finetuning_type"] == "lora":
             args["lora_rank"] = get("train.lora_rank")
             args["lora_alpha"] = get("train.lora_alpha")
             args["lora_dropout"] = get("train.lora_dropout")
@@ -163,6 +166,7 @@ class Runner:
             if args["use_llama_pro"]:
                 args["num_layer_trainable"] = get("train.num_layer_trainable")

+        # rlhf config
         if args["stage"] == "ppo":
             args["reward_model"] = ",".join(
                 [
@@ -171,31 +175,41 @@ class Runner:
                 ]
             )
             args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
+            args["ppo_score_norm"] = get("train.ppo_score_norm")
+            args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
+            args["top_k"] = 0
+            args["top_p"] = 0.9
         elif args["stage"] == "dpo":
-            args["dpo_beta"] = get("train.dpo_beta")
-            args["dpo_ftx"] = get("train.dpo_ftx")
+            args["dpo_beta"] = get("train.pref_beta")
+            args["dpo_ftx"] = get("train.pref_ftx")
+            args["dpo_loss"] = get("train.pref_loss")
+        elif args["stage"] == "kto":
+            args["kto_beta"] = get("train.pref_beta")
+            args["kto_ftx"] = get("train.pref_ftx")
         elif args["stage"] == "orpo":
-            args["orpo_beta"] = get("train.orpo_beta")
-
-        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
-            args["val_size"] = get("train.val_size")
-            args["evaluation_strategy"] = "steps"
-            args["eval_steps"] = args["save_steps"]
-            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
-            args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
+            args["orpo_beta"] = get("train.pref_beta")

+        # galore config
         if args["use_galore"]:
             args["galore_rank"] = get("train.galore_rank")
             args["galore_update_interval"] = get("train.galore_update_interval")
             args["galore_scale"] = get("train.galore_scale")
             args["galore_target"] = get("train.galore_target")

+        # badam config
         if args["use_badam"]:
             args["badam_mode"] = get("train.badam_mode")
             args["badam_switch_mode"] = get("train.badam_switch_mode")
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")

+        # eval config
+        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+            args["val_size"] = get("train.val_size")
+            args["evaluation_strategy"] = "steps"
+            args["eval_steps"] = args["save_steps"]
+            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]

         return args

     def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: