Merge branch 'main' into add_dataset_sample_num
This commit is contained in:
commit
27cb51f7f8
17
README.md
17
README.md
|
@ -69,12 +69,12 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Changelog
|
||||
|
||||
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
||||
|
||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
||||
|
||||
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
@ -160,6 +160,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
|
@ -284,11 +285,11 @@ huggingface-cli login
|
|||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.0 |
|
||||
| transformers | 4.37.2 | 4.40.1 |
|
||||
| transformers | 4.37.2 | 4.41.0 |
|
||||
| datasets | 2.14.3 | 2.19.1 |
|
||||
| accelerate | 0.27.2 | 0.30.0 |
|
||||
| peft | 0.9.0 | 0.10.0 |
|
||||
| trl | 0.8.1 | 0.8.6 |
|
||||
| accelerate | 0.27.2 | 0.30.1 |
|
||||
| peft | 0.9.0 | 0.11.1 |
|
||||
| trl | 0.8.2 | 0.8.6 |
|
||||
|
||||
| Optional | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
|
@ -344,6 +345,8 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
|||
|
||||
<details><summary>For Ascend NPU users</summary>
|
||||
|
||||
Join [NPU user group](assets/wechat_npu.jpg).
|
||||
|
||||
To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.
|
||||
|
||||
| Requirement | Minimum | Recommend |
|
||||
|
@ -356,7 +359,7 @@ To utilize Ascend NPU devices for (distributed) training and inference, you need
|
|||
Docker image:
|
||||
|
||||
- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
||||
- 64GB: Coming soon
|
||||
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
||||
|
||||
|
|
17
README_zh.md
17
README_zh.md
|
@ -69,12 +69,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
## 更新日志
|
||||
|
||||
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
|
||||
|
||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
|
||||
|
||||
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
@ -160,6 +160,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
|
@ -284,11 +285,11 @@ huggingface-cli login
|
|||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.0 |
|
||||
| transformers | 4.37.2 | 4.40.1 |
|
||||
| transformers | 4.37.2 | 4.41.0 |
|
||||
| datasets | 2.14.3 | 2.19.1 |
|
||||
| accelerate | 0.27.2 | 0.30.0 |
|
||||
| peft | 0.9.0 | 0.10.0 |
|
||||
| trl | 0.8.1 | 0.8.6 |
|
||||
| accelerate | 0.27.2 | 0.30.1 |
|
||||
| peft | 0.9.0 | 0.11.1 |
|
||||
| trl | 0.8.2 | 0.8.6 |
|
||||
|
||||
| 可选项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
|
@ -344,6 +345,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||
|
||||
<details><summary>昇腾 NPU 用户指南</summary>
|
||||
|
||||
加入 [NPU 用户群](assets/wechat_npu.jpg)。
|
||||
|
||||
如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。
|
||||
|
||||
| 依赖项 | 至少 | 推荐 |
|
||||
|
@ -356,7 +359,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||
Docker 镜像:
|
||||
|
||||
- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
||||
- 64GB:敬请期待
|
||||
- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 146 KiB After Width: | Height: | Size: 145 KiB |
Binary file not shown.
After Width: | Height: | Size: 146 KiB |
|
@ -7,7 +7,7 @@
|
|||
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
|
|
|
@ -34,7 +34,8 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
|||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Sequence(datasets.Value("string")),
|
||||
"chosen": datasets.Value("string"),
|
||||
"rejected": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch
|
|||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_TOKEN
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine):
|
|||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
|
||||
messages[0]["content"] = "<image>" + messages[0]["content"]
|
||||
if (
|
||||
processor is not None
|
||||
and image is not None
|
||||
and not hasattr(processor, "image_seq_length")
|
||||
and IMAGE_TOKEN not in messages[0]["content"]
|
||||
): # llava-like models
|
||||
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
pixel_values = None
|
||||
prompt_ids, _ = template.encode_oneturn(
|
||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
if processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
batch_feature = image_processor(image, return_tensors="pt")
|
||||
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
|
||||
|
@ -122,10 +137,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if processor is not None and image is not None:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import uuid
|
|||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_TOKEN
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count, infer_optim_dtype
|
||||
from ..extras.packages import is_vllm_available
|
||||
|
@ -17,7 +18,6 @@ if is_vllm_available():
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
@ -67,7 +67,7 @@ class VllmEngine(BaseEngine):
|
|||
patch_size = config.vision_config.patch_size
|
||||
self.image_feature_size = (image_size // patch_size) ** 2
|
||||
engine_args["image_input_type"] = "pixel_values"
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
||||
engine_args["image_feature_size"] = self.image_feature_size
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
|
@ -92,14 +92,28 @@ class VllmEngine(BaseEngine):
|
|||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
|
||||
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
|
||||
|
||||
if (
|
||||
self.processor is not None
|
||||
and image is not None
|
||||
and not hasattr(self.processor, "image_seq_length")
|
||||
and IMAGE_TOKEN not in messages[0]["content"]
|
||||
): # llava-like models
|
||||
messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
|
||||
if self.processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
use_beam_search: bool = self.generating_args["num_beams"] > 1
|
||||
|
@ -144,13 +158,6 @@ class VllmEngine(BaseEngine):
|
|||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
if self.processor is not None and image is not None:
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
result_generator = self.model.generate(
|
||||
prompt=None,
|
||||
sampling_params=sampling_params,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
@ -11,21 +11,6 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
||||
r"""
|
||||
Masks out the input ids except for the responses.
|
||||
"""
|
||||
padded_labels = []
|
||||
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
||||
if self.tokenizer.padding_side == "left":
|
||||
start, end = feature.size(0) - answer_len, feature.size(0)
|
||||
else:
|
||||
start, end = prompt_len, prompt_len + answer_len
|
||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||
padded_tensor[start:end] = feature[start:end]
|
||||
padded_labels.append(padded_tensor)
|
||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
@ -34,21 +19,22 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||
the last n examples represent rejected examples.
|
||||
"""
|
||||
concatenated_features = []
|
||||
label_positions = []
|
||||
for key in ("chosen_ids", "rejected_ids"):
|
||||
for key in ("chosen", "rejected"):
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||
concatenated_features.append(
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (prompt_len + answer_len),
|
||||
target_feature = {
|
||||
"input_ids": feature["{}_input_ids".format(key)],
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
}
|
||||
)
|
||||
label_positions.append((prompt_len, answer_len))
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
batch = super().__call__(concatenated_features)
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
if "{}_token_type_ids".format(key) in feature:
|
||||
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
||||
|
||||
concatenated_features.append(target_feature)
|
||||
|
||||
return super().__call__(concatenated_features)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -62,20 +48,25 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
target_features.append(
|
||||
{
|
||||
target_feature = {
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
)
|
||||
kl_features.append(
|
||||
{
|
||||
kl_feature = {
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
)
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "token_type_ids" in feature:
|
||||
target_feature["token_type_ids"] = feature["token_type_ids"]
|
||||
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
||||
|
||||
target_features.append(target_feature)
|
||||
kl_features.append(kl_feature)
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
|
@ -83,5 +74,8 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
if "token_type_ids" in batch:
|
||||
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
||||
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
||||
|
|
|
@ -2,6 +2,7 @@ import inspect
|
|||
import os
|
||||
import numpy as np
|
||||
from numpy.random import RandomState
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
@ -180,12 +181,15 @@ def get_dataset(
|
|||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
|
||||
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
if stage == "pt":
|
||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||
else:
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
|
|
|
@ -1,380 +1,25 @@
|
|||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_pillow_available
|
||||
from .utils import Role
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
from .processors.feedback import preprocess_feedback_dataset
|
||||
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
||||
from .processors.pretrain import preprocess_pretrain_dataset
|
||||
from .processors.supervised import (
|
||||
preprocess_packed_supervised_dataset,
|
||||
preprocess_supervised_dataset,
|
||||
print_supervised_dataset_example,
|
||||
)
|
||||
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
# process visual inputs (currently only supports a single image)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0]
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
if data_args.template == "gemma":
|
||||
for i in range(len(result["input_ids"])):
|
||||
result["input_ids"][i][0] = tokenizer.bos_token_id
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for source_ids, target_ids in template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
chosen_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
rejected_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print(
|
||||
"labels:\n{}".format(
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
|
@ -419,7 +64,7 @@ def get_preprocess_and_print_func(
|
|||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "kto":
|
||||
preprocess_func = partial(
|
||||
preprocess_kto_dataset,
|
||||
preprocess_feedback_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
|
||||
from ...extras.logging import get_logger
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_feedback_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
model_inputs["kl_token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
|
@ -0,0 +1,27 @@
|
|||
from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
# process visual inputs (currently only supports a single image)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
||||
# get paligemma token type ids for computing loss
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
|
@ -0,0 +1,109 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
|
||||
from ...extras.logging import get_logger
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {
|
||||
"chosen_input_ids": [],
|
||||
"chosen_attention_mask": [],
|
||||
"chosen_labels": [],
|
||||
"rejected_input_ids": [],
|
||||
"rejected_attention_mask": [],
|
||||
"rejected_labels": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"] = []
|
||||
model_inputs["rejected_token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
chosen_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
rejected_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
||||
model_inputs["chosen_input_ids"].append(chosen_input_ids)
|
||||
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
|
||||
model_inputs["chosen_labels"].append(chosen_labels)
|
||||
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||
model_inputs["rejected_labels"].append(rejected_labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
|
||||
)
|
||||
model_inputs["rejected_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
|
||||
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
|
||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
||||
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
||||
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
||||
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
||||
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
||||
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
||||
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
|
@ -0,0 +1,36 @@
|
|||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
if data_args.template == "gemma":
|
||||
for i in range(len(result["input_ids"])):
|
||||
result["input_ids"][i][0] = tokenizer.bos_token_id
|
||||
|
||||
return result
|
|
@ -0,0 +1,137 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
|
||||
from ...extras.logging import get_logger
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for source_ids, target_ids in template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
|
@ -0,0 +1,76 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from ...extras.constants import IMAGE_TOKEN
|
||||
from ...extras.logging import get_logger
|
||||
from ..utils import Role
|
||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||
|
||||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
|
@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
|||
slot_items.append(placeholder)
|
||||
if slot_pieces[1]:
|
||||
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
|
||||
elif isinstance(slot, set):
|
||||
if "bos_token" in slot:
|
||||
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
||||
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
||||
slot_items.append("'" + tokenizer.bos_token + "'")
|
||||
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
|
||||
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
||||
slot_items.append("'" + tokenizer.eos_token + "'")
|
||||
elif isinstance(slot, dict):
|
||||
raise ValueError("Dict is not supported.")
|
||||
|
@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
|||
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
||||
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
||||
jinja_template += "{% endif %}"
|
||||
|
||||
jinja_template += "{% if message['role'] == 'user' %}"
|
||||
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
||||
jinja_template += "{{ " + user_message + " }}"
|
||||
|
||||
jinja_template += "{% elif message['role'] == 'assistant' %}"
|
||||
assistant_message = _convert_slots_to_jinja(
|
||||
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
|
||||
|
@ -614,6 +616,9 @@ _register_template(
|
|||
name="empty",
|
||||
format_user=StringFormatter(slots=["{{content}}"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,8 @@ FILEEXT2TYPE = {
|
|||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
IMAGE_TOKEN = "<image>"
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
@ -714,6 +716,28 @@ register_model_group(
|
|||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"PaliGemma-3B-pt-224": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
|
||||
},
|
||||
"PaliGemma-3B-pt-448": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
|
||||
},
|
||||
"PaliGemma-3B-pt-896": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
|
||||
},
|
||||
"PaliGemma-3B-mix-224": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
|
||||
},
|
||||
"PaliGemma-3B-mix-448": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
|
||||
},
|
||||
},
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {
|
||||
|
|
|
@ -65,7 +65,7 @@ def check_dependencies() -> None:
|
|||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
|
||||
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
|
||||
require_version("trl>=0.8.2", "To fix: pip install trl>=0.8.2")
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
|
|
|
@ -145,7 +145,7 @@ class ModelArguments:
|
|||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||
)
|
||||
export_device: str = field(
|
||||
export_device: Literal["cpu", "cuda"] = field(
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
|
||||
)
|
||||
|
|
|
@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
if model_args.export_dir is not None:
|
||||
model_args.device_map = {"": torch.device(model_args.export_device)}
|
||||
if model_args.export_dir is not None and model_args.export_device == "cpu":
|
||||
model_args.device_map = {"": torch.device("cpu")}
|
||||
else:
|
||||
model_args.device_map = "auto"
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from types import MethodType
|
|||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
|
@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
|
||||
Otherwise the average log probabilities.
|
||||
"""
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
all_logits: "torch.Tensor" = model(
|
||||
input_ids=batch_copied["input_ids"],
|
||||
attention_mask=batch_copied["attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
batch_copied = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||
all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
logits=all_logits,
|
||||
|
|
|
@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
|
|||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
with torch.no_grad():
|
||||
kl_logits = model(
|
||||
input_ids=batch["kl_input_ids"],
|
||||
attention_mask=batch["kl_attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
|
||||
if "pixel_values" in batch:
|
||||
kl_model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
target_logits = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
if "kl_token_type_ids" in batch:
|
||||
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
|
||||
|
||||
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "token_type_ids" in batch:
|
||||
model_inputs["token_type_ids"] = batch["token_type_ids"]
|
||||
|
||||
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
target_logps = self.get_batch_logps(
|
||||
logits=target_logits,
|
||||
|
|
|
@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
|
|||
r"""
|
||||
Computes the average log probabilities of the labels under the given logits.
|
||||
"""
|
||||
all_logits: "torch.Tensor" = model(
|
||||
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
|
||||
).logits.to(torch.float32)
|
||||
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
logits=all_logits,
|
||||
|
|
|
@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||
with gr.Column():
|
||||
ppo_score_norm = gr.Checkbox()
|
||||
ppo_whiten_rewards = gr.Checkbox()
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
||||
input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
|
||||
elem_dict.update(
|
||||
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
|
||||
dict(
|
||||
rlhf_tab=rlhf_tab,
|
||||
pref_beta=pref_beta,
|
||||
pref_ftx=pref_ftx,
|
||||
pref_loss=pref_loss,
|
||||
reward_model=reward_model,
|
||||
ppo_score_norm=ppo_score_norm,
|
||||
ppo_whiten_rewards=ppo_whiten_rewards,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as galore_tab:
|
||||
|
|
|
@ -774,52 +774,52 @@ LOCALES = {
|
|||
"label": "RLHF 参数设置",
|
||||
},
|
||||
},
|
||||
"dpo_beta": {
|
||||
"pref_beta": {
|
||||
"en": {
|
||||
"label": "DPO beta",
|
||||
"info": "Value of the beta parameter in the DPO loss.",
|
||||
"label": "Beta value",
|
||||
"info": "Value of the beta parameter in the loss.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "DPO бета",
|
||||
"info": "Значение параметра бета в функции потерь DPO.",
|
||||
"label": "Бета значение",
|
||||
"info": "Значение параметра бета в функции потерь.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO beta 参数",
|
||||
"info": "DPO 损失函数中 beta 超参数大小。",
|
||||
"label": "Beta 参数",
|
||||
"info": "损失函数中 beta 超参数大小。",
|
||||
},
|
||||
},
|
||||
"dpo_ftx": {
|
||||
"pref_ftx": {
|
||||
"en": {
|
||||
"label": "DPO-ftx weight",
|
||||
"info": "The weight of SFT loss in the DPO-ftx.",
|
||||
"label": "Ftx gamma",
|
||||
"info": "The weight of SFT loss in the final loss.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Вес DPO-ftx",
|
||||
"info": "Вес функции потерь SFT в DPO-ftx.",
|
||||
"label": "Ftx гамма",
|
||||
"info": "Вес потери SFT в итоговой потере.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO-ftx 权重",
|
||||
"info": "DPO-ftx 中 SFT 损失的权重大小。",
|
||||
"label": "Ftx gamma",
|
||||
"info": "损失函数中 SFT 损失的权重大小。",
|
||||
},
|
||||
},
|
||||
"orpo_beta": {
|
||||
"pref_loss": {
|
||||
"en": {
|
||||
"label": "ORPO beta",
|
||||
"info": "Value of the beta parameter in the ORPO loss.",
|
||||
"label": "Loss type",
|
||||
"info": "The type of the loss function.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "ORPO бета",
|
||||
"info": "Значение параметра бета в функции потерь ORPO.",
|
||||
"label": "Тип потерь",
|
||||
"info": "Тип функции потерь.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "ORPO beta 参数",
|
||||
"info": "ORPO 损失函数中 beta 超参数大小。",
|
||||
"label": "损失类型",
|
||||
"info": "损失函数的类型。",
|
||||
},
|
||||
},
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Adapter of the reward model for PPO training.",
|
||||
"info": "Adapter of the reward model in PPO training.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Модель вознаграждения",
|
||||
|
@ -830,6 +830,34 @@ LOCALES = {
|
|||
"info": "PPO 训练中奖励模型的适配器路径。",
|
||||
},
|
||||
},
|
||||
"ppo_score_norm": {
|
||||
"en": {
|
||||
"label": "Score norm",
|
||||
"info": "Normalizing scores in PPO training.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Норма оценок",
|
||||
"info": "Нормализация оценок в тренировке PPO.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中归一化奖励分数。",
|
||||
},
|
||||
},
|
||||
"ppo_whiten_rewards": {
|
||||
"en": {
|
||||
"label": "Whiten rewards",
|
||||
"info": "Whiten the rewards in PPO training.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Белые вознаграждения",
|
||||
"info": "Осветлите вознаграждения в обучении PPO.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "白化奖励",
|
||||
"info": "PPO 训练中将奖励分数做白化处理。",
|
||||
},
|
||||
},
|
||||
"galore_tab": {
|
||||
"en": {
|
||||
"label": "GaLore configurations",
|
||||
|
|
|
@ -145,11 +145,14 @@ class Runner:
|
|||
plot_loss=True,
|
||||
)
|
||||
|
||||
# freeze config
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
|
||||
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
|
||||
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
|
||||
elif args["finetuning_type"] == "lora":
|
||||
|
||||
# lora config
|
||||
if args["finetuning_type"] == "lora":
|
||||
args["lora_rank"] = get("train.lora_rank")
|
||||
args["lora_alpha"] = get("train.lora_alpha")
|
||||
args["lora_dropout"] = get("train.lora_dropout")
|
||||
|
@ -163,6 +166,7 @@ class Runner:
|
|||
if args["use_llama_pro"]:
|
||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||
|
||||
# rlhf config
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = ",".join(
|
||||
[
|
||||
|
@ -171,31 +175,41 @@ class Runner:
|
|||
]
|
||||
)
|
||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||
args["ppo_score_norm"] = get("train.ppo_score_norm")
|
||||
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
|
||||
args["top_k"] = 0
|
||||
args["top_p"] = 0.9
|
||||
elif args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
args["dpo_ftx"] = get("train.dpo_ftx")
|
||||
args["dpo_beta"] = get("train.pref_beta")
|
||||
args["dpo_ftx"] = get("train.pref_ftx")
|
||||
args["dpo_loss"] = get("train.pref_loss")
|
||||
elif args["stage"] == "kto":
|
||||
args["kto_beta"] = get("train.pref_beta")
|
||||
args["kto_ftx"] = get("train.pref_ftx")
|
||||
elif args["stage"] == "orpo":
|
||||
args["orpo_beta"] = get("train.orpo_beta")
|
||||
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = args["save_steps"]
|
||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
|
||||
args["orpo_beta"] = get("train.pref_beta")
|
||||
|
||||
# galore config
|
||||
if args["use_galore"]:
|
||||
args["galore_rank"] = get("train.galore_rank")
|
||||
args["galore_update_interval"] = get("train.galore_update_interval")
|
||||
args["galore_scale"] = get("train.galore_scale")
|
||||
args["galore_target"] = get("train.galore_target")
|
||||
|
||||
# badam config
|
||||
if args["use_badam"]:
|
||||
args["badam_mode"] = get("train.badam_mode")
|
||||
args["badam_switch_mode"] = get("train.badam_switch_mode")
|
||||
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
||||
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
||||
|
||||
# eval config
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = args["save_steps"]
|
||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
|
|
Loading…
Reference in New Issue