diff --git a/README.md b/README.md index 601f67a9..fcc96882 100644 --- a/README.md +++ b/README.md @@ -69,12 +69,12 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `gemma` template for chat completion. + [24/05/18] We supported the **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage. [24/05/14] We supported training and inference on Ascend NPU devices. Check the [installation](#installation) section for details. -[24/05/13] We supported fine-tuning the **Yi-1.5** series models. -
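The changelog entry above stresses that the PaliGemma checkpoints are pre-trained rather than instruction-tuned, so they must be fine-tuned with the `gemma` template before they can chat (the dependency table's bump to transformers 4.41.0 matches the release that introduced PaliGemma support). The sketch below is a simplified Python rendering of how this patch distinguishes PaliGemma from llava-like models in `hf_engine.py` and the new `data/processors` modules; `IMAGE_TOKEN` is assumed to be the `<image>` placeholder from `extras/constants.py`, and the helper names here are illustrative rather than the repository's exact API.

```python
from typing import List

IMAGE_TOKEN = "<image>"  # assumption: the placeholder constant from src/llamafactory/extras/constants.py


def prepend_image_tokens(prompt_ids: List[int], tokenizer, processor) -> List[int]:
    """PaliGemma-style processors expose `image_seq_length`; the patch prepends that many
    image-token ids to the encoded prompt (and masks them with IGNORE_INDEX in the labels)."""
    if hasattr(processor, "image_seq_length"):
        image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
        return [image_token_id] * processor.image_seq_length + prompt_ids
    return prompt_ids


def prepend_image_placeholder(first_user_message: str, processor) -> str:
    """Llava-like models instead carry a single placeholder inside the first user message."""
    if not hasattr(processor, "image_seq_length") and IMAGE_TOKEN not in first_user_message:
        return IMAGE_TOKEN + first_user_message
    return first_user_message
```

In both cases the image itself becomes `pixel_values` via the processor's `image_processor`; for PaliGemma the collators additionally forward `token_type_ids` built as `[0] * image_seq_length + [1] * (input_len - image_seq_length)`, matching `get_paligemma_token_type_ids` in the new `mm_utils.py`.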
Full Changelog [24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage. @@ -160,6 +160,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | +| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | @@ -284,11 +285,11 @@ huggingface-cli login | ------------ | ------- | --------- | | python | 3.8 | 3.10 | | torch | 1.13.1 | 2.2.0 | -| transformers | 4.37.2 | 4.40.1 | +| transformers | 4.37.2 | 4.41.0 | | datasets | 2.14.3 | 2.19.1 | -| accelerate | 0.27.2 | 0.30.0 | -| peft | 0.9.0 | 0.10.0 | -| trl | 0.8.1 | 0.8.6 | +| accelerate | 0.27.2 | 0.30.1 | +| peft | 0.9.0 | 0.11.1 | +| trl | 0.8.2 | 0.8.6 | | Optional | Minimum | Recommend | | ------------ | ------- | --------- | @@ -344,6 +345,8 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
For Ascend NPU users +Join [NPU user group](assets/wechat_npu.jpg). + To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. | Requirement | Minimum | Recommend | @@ -356,7 +359,7 @@ To utilize Ascend NPU devices for (distributed) training and inference, you need Docker image: - 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) -- 64GB: Coming soon +- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. diff --git a/README_zh.md b/README_zh.md index 27b122b0..2e0b4f34 100644 --- a/README_zh.md +++ b/README_zh.md @@ -69,12 +69,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 +[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。 + [24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。 -[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。 -
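As a side note to the Ascend NPU instructions earlier in this diff (mirrored in README_zh.md below): device selection must go through `ASCEND_RT_VISIBLE_DEVICES` rather than `CUDA_VISIBLE_DEVICES`. A minimal sketch, assuming `torch-npu` is installed per the linked instructions and using hypothetical device ids:

```python
import os

# Select NPUs before torch / torch_npu initialize; the device ids here are hypothetical.
os.environ.setdefault("ASCEND_RT_VISIBLE_DEVICES", "0,1")

import torch
import torch_npu  # noqa: F401  (assumed install from https://gitee.com/ascend/pytorch)

print(torch.npu.is_available())  # torch_npu attaches the `npu` namespace to torch
```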
展开日志 [24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。 @@ -160,6 +160,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | +| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | @@ -284,11 +285,11 @@ huggingface-cli login | ------------ | ------- | --------- | | python | 3.8 | 3.10 | | torch | 1.13.1 | 2.2.0 | -| transformers | 4.37.2 | 4.40.1 | +| transformers | 4.37.2 | 4.41.0 | | datasets | 2.14.3 | 2.19.1 | -| accelerate | 0.27.2 | 0.30.0 | -| peft | 0.9.0 | 0.10.0 | -| trl | 0.8.1 | 0.8.6 | +| accelerate | 0.27.2 | 0.30.1 | +| peft | 0.9.0 | 0.11.1 | +| trl | 0.8.2 | 0.8.6 | | 可选项 | 至少 | 推荐 | | ------------ | ------- | --------- | @@ -344,6 +345,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
昇腾 NPU 用户指南 +加入 [NPU 用户群](assets/wechat_npu.jpg)。 + 如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。 | 依赖项 | 至少 | 推荐 | @@ -356,7 +359,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl Docker 镜像: - 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) -- 64GB:敬请期待 +- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) 请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。 diff --git a/assets/wechat.jpg b/assets/wechat.jpg index 63a44b5e..a5d44ade 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/assets/wechat_npu.jpg b/assets/wechat_npu.jpg new file mode 100644 index 00000000..353e7603 Binary files /dev/null and b/assets/wechat_npu.jpg differ diff --git a/data/README_zh.md b/data/README_zh.md index 1f6a5ba3..1427e48d 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -7,7 +7,7 @@ "hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)", "ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)", "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)", - "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)", + "file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)", "formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)", "ranking": "是否为偏好数据集(可选,默认:False)", "subset": "数据集子集的名称(可选,默认:None)", diff --git a/data/hh_rlhf_en/hh_rlhf_en.py b/data/hh_rlhf_en/hh_rlhf_en.py index 1bc18f4f..aa108fa7 100644 --- a/data/hh_rlhf_en/hh_rlhf_en.py +++ b/data/hh_rlhf_en/hh_rlhf_en.py @@ -34,7 +34,8 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder): features = datasets.Features( { "instruction": datasets.Value("string"), - "output": datasets.Sequence(datasets.Value("string")), + "chosen": datasets.Value("string"), + "rejected": datasets.Value("string"), "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))), } ) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 57cdc89a..5f0d02a7 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -8,6 +8,7 @@ import torch from transformers import GenerationConfig, TextIteratorStreamer from ..data import get_template_and_fix_tokenizer +from ..extras.constants import IMAGE_TOKEN from ..extras.misc import get_logits_processor from ..model import load_model, load_tokenizer from .base_engine import BaseEngine, Response @@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine): image: Optional["NDArray"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> Tuple[Dict[str, Any], int]: - if processor is not None and image is not None and "" not in messages[0]["content"]: - messages[0]["content"] = "" + messages[0]["content"] + if ( + processor is not None + and image is not None + and not hasattr(processor, "image_seq_length") + and IMAGE_TOKEN not in messages[0]["content"] + ): # llava-like models + messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or generating_args["default_system"] + pixel_values = None prompt_ids, _ = template.encode_oneturn( tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools ) + if processor is not None and image is not None: # add image features + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + batch_feature = image_processor(image, 
return_tensors="pt") + pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W) + if hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) @@ -122,10 +137,8 @@ class HuggingfaceEngine(BaseEngine): logits_processor=get_logits_processor(), ) - if processor is not None and image is not None: - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"] - gen_kwargs["pixel_values"] = pixel_values.to(model.device) + if pixel_values is not None: + gen_kwargs["pixel_values"] = pixel_values return gen_kwargs, prompt_length diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 44b9651f..e424481f 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -2,6 +2,7 @@ import uuid from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from ..data import get_template_and_fix_tokenizer +from ..extras.constants import IMAGE_TOKEN from ..extras.logging import get_logger from ..extras.misc import get_device_count, infer_optim_dtype from ..extras.packages import is_vllm_available @@ -17,7 +18,6 @@ if is_vllm_available(): if TYPE_CHECKING: - import torch from numpy.typing import NDArray from transformers.image_processing_utils import BaseImageProcessor @@ -67,7 +67,7 @@ class VllmEngine(BaseEngine): patch_size = config.vision_config.patch_size self.image_feature_size = (image_size // patch_size) ** 2 engine_args["image_input_type"] = "pixel_values" - engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("") + engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size) engine_args["image_feature_size"] = self.image_feature_size if getattr(config, "is_yi_vl_derived_model", None): @@ -92,14 +92,28 @@ class VllmEngine(BaseEngine): **input_kwargs, ) -> AsyncIterator["RequestOutput"]: request_id = "chatcmpl-{}".format(uuid.uuid4().hex) - if self.processor is not None and image is not None and "" not in messages[0]["content"]: - messages[0]["content"] = "" * self.image_feature_size + messages[0]["content"] + + if ( + self.processor is not None + and image is not None + and not hasattr(self.processor, "image_seq_length") + and IMAGE_TOKEN not in messages[0]["content"] + ): # llava-like models + messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or self.generating_args["default_system"] prompt_ids, _ = self.template.encode_oneturn( tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools ) + + if self.processor is not None and image is not None: # add image features + image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") + pixel_values = image_processor(image, return_tensors="pt")["pixel_values"] + multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) + else: + multi_modal_data = None + prompt_length = len(prompt_ids) use_beam_search: bool = self.generating_args["num_beams"] > 1 @@ -144,13 +158,6 @@ class 
VllmEngine(BaseEngine): skip_special_tokens=True, ) - if self.processor is not None and image is not None: - image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") - pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"] - multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) - else: - multi_modal_data = None - result_generator = self.model.generate( prompt=None, sampling_params=sampling_params, diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 474d6a30..1dc8dd8d 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Dict, Sequence import torch from transformers import DataCollatorForSeq2Seq @@ -11,21 +11,6 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): Data collator for pairwise data. """ - def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: - r""" - Masks out the input ids except for the responses. - """ - padded_labels = [] - for feature, (prompt_len, answer_len) in zip(batch, positions): - if self.tokenizer.padding_side == "left": - start, end = feature.size(0) - answer_len, feature.size(0) - else: - start, end = prompt_len, prompt_len + answer_len - padded_tensor = self.label_pad_token_id * torch.ones_like(feature) - padded_tensor[start:end] = feature[start:end] - padded_labels.append(padded_tensor) - return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: r""" Pads batched data to the longest sequence in the batch. @@ -34,21 +19,22 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): the last n examples represent rejected examples. 
""" concatenated_features = [] - label_positions = [] - for key in ("chosen_ids", "rejected_ids"): + for key in ("chosen", "rejected"): for feature in features: - prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) - concatenated_features.append( - { - "input_ids": feature["prompt_ids"] + feature[key], - "attention_mask": [1] * (prompt_len + answer_len), - } - ) - label_positions.append((prompt_len, answer_len)) + target_feature = { + "input_ids": feature["{}_input_ids".format(key)], + "attention_mask": feature["{}_attention_mask".format(key)], + "labels": feature["{}_labels".format(key)], + } + if "pixel_values" in feature: + target_feature["pixel_values"] = feature["pixel_values"] - batch = super().__call__(concatenated_features) - batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) - return batch + if "{}_token_type_ids".format(key) in feature: + target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)] + + concatenated_features.append(target_feature) + + return super().__call__(concatenated_features) @dataclass @@ -62,20 +48,25 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): kl_features = [] kto_tags = [] for feature in features: - target_features.append( - { - "input_ids": feature["input_ids"], - "attention_mask": feature["attention_mask"], - "labels": feature["labels"], - } - ) - kl_features.append( - { - "input_ids": feature["kl_input_ids"], - "attention_mask": feature["kl_attention_mask"], - "labels": feature["kl_labels"], - } - ) + target_feature = { + "input_ids": feature["input_ids"], + "attention_mask": feature["attention_mask"], + "labels": feature["labels"], + } + kl_feature = { + "input_ids": feature["kl_input_ids"], + "attention_mask": feature["kl_attention_mask"], + "labels": feature["kl_labels"], + } + if "pixel_values" in feature: + target_feature["pixel_values"] = feature["pixel_values"] + + if "token_type_ids" in feature: + target_feature["token_type_ids"] = feature["token_type_ids"] + kl_feature["token_type_ids"] = feature["kl_token_type_ids"] + + target_features.append(target_feature) + kl_features.append(kl_feature) kto_tags.append(feature["kto_tags"]) batch = super().__call__(target_features) @@ -83,5 +74,8 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): batch["kl_input_ids"] = kl_batch["input_ids"] batch["kl_attention_mask"] = kl_batch["attention_mask"] batch["kl_labels"] = kl_batch["labels"] + if "token_type_ids" in batch: + batch["kl_token_type_ids"] = kl_batch["token_type_ids"] + batch["kto_tags"] = torch.tensor(kto_tags) return batch diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index a45e025f..5ce4392e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -2,6 +2,7 @@ import inspect import os import numpy as np from numpy.random import RandomState +import sys from typing import TYPE_CHECKING, Literal, Optional, Union from datasets import load_dataset, load_from_disk @@ -180,12 +181,15 @@ def get_dataset( logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) - exit(0) + sys.exit(0) if training_args.should_log: try: print_function(next(iter(dataset))) except StopIteration: - raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") + if stage == "pt": + raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") + else: + raise 
RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") return dataset diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 557678e6..336257ca 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -1,380 +1,25 @@ from functools import partial -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple -from ..extras.constants import IGNORE_INDEX -from ..extras.logging import get_logger -from ..extras.packages import is_pillow_available -from .utils import Role - - -if is_pillow_available(): - from PIL import Image +from .processors.feedback import preprocess_feedback_dataset +from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example +from .processors.pretrain import preprocess_pretrain_dataset +from .processors.supervised import ( + preprocess_packed_supervised_dataset, + preprocess_supervised_dataset, + print_supervised_dataset_example, +) +from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example if TYPE_CHECKING: - from numpy.typing import NDArray - from PIL.Image import Image as ImageObject from transformers import ProcessorMixin, Seq2SeqTrainingArguments - from transformers.image_processing_utils import BaseImageProcessor from transformers.tokenization_utils import PreTrainedTokenizer from ..hparams import DataArguments from .template import Template -logger = get_logger(__name__) - - -def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": - # process visual inputs (currently only supports a single image) - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) - return image_processor(image, return_tensors="pt")["pixel_values"][0] - - -def preprocess_pretrain_dataset( - examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" -) -> Dict[str, List[List[int]]]: - # build grouped texts with format `X1 X2 X3 ...` if packing is enabled - text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] - - if not data_args.packing: - if data_args.template == "gemma": - text_examples = [tokenizer.bos_token + example for example in text_examples] - - result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) - else: - tokenized_examples = tokenizer(text_examples, add_special_tokens=False) - concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} - total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) - block_size = data_args.cutoff_len - total_length = (total_length // block_size) * block_size - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - if data_args.template == "gemma": - for i in range(len(result["input_ids"])): - result["input_ids"][i][0] = tokenizer.bos_token_id - - return result - - -def preprocess_supervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build inputs with format 
` X Y ` and labels with format ` ... Y ` - # for multiturn examples, we only mask the prompt part in each prompt-response pair. - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] - - messages = examples["prompt"][i] + examples["response"][i] - input_ids, labels = [], [] - for turn_idx, (source_ids, target_ids) in enumerate( - template.encode_multiturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - return model_inputs - - -def preprocess_packed_supervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build inputs with format ` X1 Y1 X2 Y2 ` - # and labels with format ` ... Y1 ... 
Y2 ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - input_ids, labels = [], [] - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - messages = examples["prompt"][i] + examples["response"][i] - for source_ids, target_ids in template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tools"][i] - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif len(input_ids) != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - - total_length = len(input_ids) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - for i in range(0, total_length, block_size): - if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): - model_inputs["input_ids"].append(input_ids[i : i + block_size]) - model_inputs["attention_mask"].append([1] * block_size) - model_inputs["labels"].append(labels[i : i + block_size]) - - return model_inputs - - -def preprocess_unsupervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build inputs with format ` X` and labels with format `Y ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] - - if len(examples["response"][i]) == 1: - messages = examples["prompt"][i] + examples["response"][i] - else: - messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] - - input_ids, labels = template.encode_oneturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - labels += [tokenizer.eos_token_id] - - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - return model_inputs - - -def preprocess_pairwise_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build input pairs with format ` X`, `Y1 ` and `Y2 ` - model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} - if processor is not None: - 
model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] - - chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] - rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] - prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, - chosen_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - _, rejected_ids = template.encode_oneturn( - tokenizer, - rejected_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - chosen_ids += [tokenizer.eos_token_id] - rejected_ids += [tokenizer.eos_token_id] - - model_inputs["prompt_ids"].append(prompt_ids) - model_inputs["chosen_ids"].append(chosen_ids) - model_inputs["rejected_ids"].append(rejected_ids) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - return model_inputs - - -def preprocess_kto_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs - kl_response = examples["response"][::-1] - model_inputs = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "kl_input_ids": [], - "kl_attention_mask": [], - "kl_labels": [], - "kto_tags": [], - } - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] - - if examples["response"][i][0]["content"]: # desired example - kto_tag = True - messages = examples["prompt"][i] + [examples["response"][i][0]] - else: # undesired example - kto_tag = False - messages = examples["prompt"][i] + [examples["response"][i][1]] - - if kl_response[i][0]["content"]: - kl_messages = examples["prompt"][i] + [kl_response[i][0]] - else: - kl_messages = examples["prompt"][i] + [kl_response[i][1]] - - prompt_ids, response_ids = template.encode_oneturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - _, kl_response_ids = template.encode_oneturn( - tokenizer, - kl_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - response_ids += [tokenizer.eos_token_id] - kl_response_ids += [tokenizer.eos_token_id] - - input_ids = prompt_ids + response_ids - labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids - kl_input_ids = prompt_ids + kl_response_ids - kl_labels = [IGNORE_INDEX] * len(prompt_ids) 
+ kl_response_ids - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - model_inputs["kl_input_ids"].append(kl_input_ids) - model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) - model_inputs["kl_labels"].append(kl_labels) - model_inputs["kto_tags"].append(kto_tag) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) - undesirable_num = len(model_inputs["kto_tags"]) - desirable_num - if desirable_num == 0 or undesirable_num == 0: - logger.warning("Your dataset only has one preference type.") - - return model_inputs - - -def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - print("label_ids:\n{}".format(example["labels"])) - print( - "labels:\n{}".format( - tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) - ) - ) - - -def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("prompt_ids:\n{}".format(example["prompt_ids"])) - print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) - print("chosen_ids:\n{}".format(example["chosen_ids"])) - print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) - print("rejected_ids:\n{}".format(example["rejected_ids"])) - print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) - - -def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - - def get_preprocess_and_print_func( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", @@ -419,7 +64,7 @@ def get_preprocess_and_print_func( print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) elif stage == "kto": preprocess_func = partial( - preprocess_kto_dataset, + preprocess_feedback_dataset, template=template, tokenizer=tokenizer, processor=processor, diff --git a/src/llamafactory/data/processors/__init__.py b/src/llamafactory/data/processors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py new file mode 100644 index 00000000..51db3e26 --- /dev/null +++ b/src/llamafactory/data/processors/feedback.py @@ -0,0 +1,110 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN +from ...extras.logging import get_logger +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_feedback_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: 
"DataArguments", +) -> Dict[str, List[List[int]]]: + # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs + kl_response = examples["response"][::-1] + model_inputs = { + "input_ids": [], + "attention_mask": [], + "labels": [], + "kl_input_ids": [], + "kl_attention_mask": [], + "kl_labels": [], + "kto_tags": [], + } + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + model_inputs["kl_token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + if examples["response"][i][0]["content"]: # desired example + kto_tag = True + messages = examples["prompt"][i] + [examples["response"][i][0]] + else: # undesired example + kto_tag = False + messages = examples["prompt"][i] + [examples["response"][i][1]] + + if kl_response[i][0]["content"]: + kl_messages = examples["prompt"][i] + [kl_response[i][0]] + else: + kl_messages = examples["prompt"][i] + [kl_response[i][1]] + + prompt_ids, response_ids = template.encode_oneturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + _, kl_response_ids = template.encode_oneturn( + tokenizer, + kl_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + + if template.efficient_eos: + response_ids += [tokenizer.eos_token_id] + kl_response_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + input_ids = prompt_ids + response_ids + labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + kl_input_ids = prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["kl_input_ids"].append(kl_input_ids) + model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) + model_inputs["kl_labels"].append(kl_labels) + model_inputs["kto_tags"].append(kto_tag) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor)) + + desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) + undesirable_num = len(model_inputs["kto_tags"]) - desirable_num + if desirable_num == 0 or undesirable_num == 0: + logger.warning("Your dataset only has one preference type.") + + return model_inputs diff --git a/src/llamafactory/data/processors/mm_utils.py b/src/llamafactory/data/processors/mm_utils.py new file mode 100644 index 00000000..abc7c4b2 --- /dev/null +++ 
b/src/llamafactory/data/processors/mm_utils.py @@ -0,0 +1,27 @@ +from typing import TYPE_CHECKING, List, Sequence + +from ...extras.packages import is_pillow_available + + +if is_pillow_available(): + from PIL import Image + + +if TYPE_CHECKING: + from numpy.typing import NDArray + from PIL.Image import Image as ImageObject + from transformers import ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + + +def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": + # process visual inputs (currently only supports a single image) + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) + return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W) + + +def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]: + # get paligemma token type ids for computing loss + image_seq_length = getattr(processor, "image_seq_length") + return [0] * image_seq_length + [1] * (input_len - image_seq_length) diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py new file mode 100644 index 00000000..ec0fb96e --- /dev/null +++ b/src/llamafactory/data/processors/pairwise.py @@ -0,0 +1,109 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN +from ...extras.logging import get_logger +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = { + "chosen_input_ids": [], + "chosen_attention_mask": [], + "chosen_labels": [], + "rejected_input_ids": [], + "rejected_attention_mask": [], + "rejected_labels": [], + } + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["chosen_token_type_ids"] = [] + model_inputs["rejected_token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] + rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] + prompt_ids, chosen_ids = template.encode_oneturn( + tokenizer, + chosen_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + _, rejected_ids = template.encode_oneturn( + tokenizer, + rejected_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + + if template.efficient_eos: + 
chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids + model_inputs["chosen_input_ids"].append(chosen_input_ids) + model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["chosen_token_type_ids"].append( + get_paligemma_token_type_ids(len(chosen_input_ids), processor) + ) + model_inputs["rejected_token_type_ids"].append( + get_paligemma_token_type_ids(len(rejected_input_ids), processor) + ) + + return model_inputs + + +def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) + valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) + print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) + print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) + print("chosen_label_ids:\n{}".format(example["chosen_labels"])) + print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False))) + print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) + print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) + print("rejected_label_ids:\n{}".format(example["rejected_labels"])) + print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py new file mode 100644 index 00000000..3de0d1ac --- /dev/null +++ b/src/llamafactory/data/processors/pretrain.py @@ -0,0 +1,36 @@ +from itertools import chain +from typing import TYPE_CHECKING, Any, Dict, List + + +if TYPE_CHECKING: + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + + +def preprocess_pretrain_dataset( + examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" +) -> Dict[str, List[List[int]]]: + # build grouped texts with format `X1 X2 X3 ...` if packing is enabled + text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] + + if not data_args.packing: + if data_args.template == "gemma": + text_examples = [tokenizer.bos_token + example for example in text_examples] + + result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) + else: + tokenized_examples = tokenizer(text_examples, add_special_tokens=False) + concatenated_examples = {k: 
list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.cutoff_len + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + if data_args.template == "gemma": + for i in range(len(result["input_ids"])): + result["input_ids"][i][0] = tokenizer.bos_token_id + + return result diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py new file mode 100644 index 00000000..80326d98 --- /dev/null +++ b/src/llamafactory/data/processors/supervised.py @@ -0,0 +1,137 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN +from ...extras.logging import get_logger +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_supervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + messages = examples["prompt"][i] + examples["response"][i] + input_ids, labels = [], [] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + input_ids += [image_token_id] * getattr(processor, "image_seq_length") + labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") + + for turn_idx, (source_ids, target_ids) in enumerate( + template.encode_multiturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + if processor is not None: + 
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + + return model_inputs + + +def preprocess_packed_supervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... Y2 ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + input_ids, labels = [], [] + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + messages = examples["prompt"][i] + examples["response"][i] + for source_ids, target_ids in template.encode_multiturn( + tokenizer, messages, examples["system"][i], examples["tools"][i] + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif len(input_ids) != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + total_length = len(input_ids) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + for i in range(0, total_length, block_size): + if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): + model_inputs["input_ids"].append(input_ids[i : i + block_size]) + model_inputs["attention_mask"].append([1] * block_size) + model_inputs["labels"].append(labels[i : i + block_size]) + + return model_inputs + + +def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py new file mode 100644 index 00000000..4adf4f61 --- /dev/null +++ b/src/llamafactory/data/processors/unsupervised.py @@ -0,0 +1,76 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IMAGE_TOKEN +from ...extras.logging import get_logger +from ..utils import Role +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_unsupervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) 
-> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X` and labels with format `Y <eos>`
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    if processor is not None:
+        model_inputs["pixel_values"] = []
+        if hasattr(processor, "image_seq_length"):  # paligemma models
+            model_inputs["token_type_ids"] = []
+
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) % 2 != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+            continue
+
+        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+            examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
+
+        if len(examples["response"][i]) == 1:
+            messages = examples["prompt"][i] + examples["response"][i]
+        else:
+            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
+
+        input_ids, labels = template.encode_oneturn(
+            tokenizer,
+            messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
+        )
+
+        if template.efficient_eos:
+            labels += [tokenizer.eos_token_id]
+
+        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+            image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+            input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
+
+        model_inputs["input_ids"].append(input_ids)
+        model_inputs["attention_mask"].append([1] * len(input_ids))
+        model_inputs["labels"].append(labels)
+        if processor is not None:
+            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+            if hasattr(processor, "image_seq_length"):  # paligemma models
+                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+
+    return model_inputs
+
+
+def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 66e9dca5..bf7133a9 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
             slot_items.append(placeholder)
             if slot_pieces[1]:
                 slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
-        elif isinstance(slot, set):
-            if "bos_token" in slot:
+        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+            if "bos_token" in slot and tokenizer.bos_token_id is not None:
                 slot_items.append("'" + tokenizer.bos_token + "'")
-            elif "eos_token" in slot:  # do not use {{ eos_token }} since it may be replaced
+            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                 slot_items.append("'" + tokenizer.eos_token + "'")
         elif isinstance(slot, dict):
             raise ValueError("Dict is not supported.")
@@ -325,9 +325,11 @@
         jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
         jinja_template += "{% set content = " + system_message + " + message['content'] %}"
         jinja_template += "{% endif %}"
+
     jinja_template += "{% if message['role'] == 'user' %}"
     user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
     jinja_template += "{{ " + user_message + " }}"
+
     jinja_template += "{% elif message['role'] == 'assistant' %}"
     assistant_message = _convert_slots_to_jinja(
         template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@@ -614,6 +616,9 @@ _register_template(
     name="empty",
     format_user=StringFormatter(slots=["{{content}}"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    efficient_eos=True,
+    force_system=True,
 )
 
 
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index fecf0c38..ae088e66 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -22,6 +22,8 @@ FILEEXT2TYPE = {
 
 IGNORE_INDEX = -100
 
+IMAGE_TOKEN = "<image>"
+
 LAYERNORM_NAMES = {"norm", "ln"}
 
 METHODS = ["full", "freeze", "lora"]
@@ -714,6 +716,28 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "PaliGemma-3B-pt-224": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
+        },
+        "PaliGemma-3B-pt-448": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
+        },
+        "PaliGemma-3B-pt-896": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
+        },
+        "PaliGemma-3B-mix-224": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
+        },
+        "PaliGemma-3B-mix-448": {
+            DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
+        },
+    },
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Phi-1.5-1.3B": {
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 0addf315..0dc07d28 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -65,7 +65,7 @@ def check_dependencies() -> None:
     require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
     require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
     require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
-    require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
+    require_version("trl>=0.8.2", "To fix: pip install trl>=0.8.2")
 
 
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 5885bb09..650d1c22 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -145,7 +145,7 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
-    export_device: str = field(
+    export_device: Literal["cpu", "cuda"] = field(
         default="cpu",
         metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
     )
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 20f9a003..6311297e 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
-    if model_args.export_dir is not None:
-        model_args.device_map = {"": torch.device(model_args.export_device)}
+    if model_args.export_dir is not None and model_args.export_device == "cpu":
+        model_args.device_map = {"": torch.device("cpu")}
     else:
         model_args.device_map = "auto"
 
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 519e95f1..23aa2c8a 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -4,7 +4,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
-from transformers import BatchEncoding, Trainer
+from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
@@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
 
         Otherwise the average log probabilities.
         """
-        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
-
-        all_logits: "torch.Tensor" = model(
-            input_ids=batch_copied["input_ids"],
-            attention_mask=batch_copied["attention_mask"],
-            return_dict=True,
-            use_cache=False,
-        ).logits.to(torch.float32)
+        batch_copied = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
             logits=all_logits,
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 5578c50c..b0e42406 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         with torch.no_grad():
-            kl_logits = model(
-                input_ids=batch["kl_input_ids"],
-                attention_mask=batch["kl_attention_mask"],
-                return_dict=True,
-                use_cache=False,
-            ).logits.to(torch.float32)
+            kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
+            if "pixel_values" in batch:
+                kl_model_inputs["pixel_values"] = batch["pixel_values"]
 
-        target_logits = model(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-            return_dict=True,
-            use_cache=False,
-        ).logits.to(torch.float32)
+            if "kl_token_type_ids" in batch:
+                kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
+
+            kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
+
+        model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
+        if "pixel_values" in batch:
+            model_inputs["pixel_values"] = batch["pixel_values"]
+
+        if "token_type_ids" in batch:
+            model_inputs["token_type_ids"] = batch["token_type_ids"]
+
+        target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 
         target_logps = self.get_batch_logps(
             logits=target_logits,
diff --git a/src/llamafactory/train/orpo/trainer.py b/src/llamafactory/train/orpo/trainer.py
index 1b743647..7cfdb429 100644
--- a/src/llamafactory/train/orpo/trainer.py
+++ b/src/llamafactory/train/orpo/trainer.py
@@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
         r"""
         Computes the average log probabilities of the labels under the given logits.
         """
-        all_logits: "torch.Tensor" = model(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
-        ).logits.to(torch.float32)
+        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
             logits=all_logits,
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index be853604..9b48c89a 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
 
     with gr.Accordion(open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
-            dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
-            orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+            pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+            pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
+            pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
             reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
+            with gr.Column():
+                ppo_score_norm = gr.Checkbox()
+                ppo_whiten_rewards = gr.Checkbox()
 
-    input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
+    input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
     elem_dict.update(
-        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
+        dict(
+            rlhf_tab=rlhf_tab,
+            pref_beta=pref_beta,
+            pref_ftx=pref_ftx,
+            pref_loss=pref_loss,
+            reward_model=reward_model,
+            ppo_score_norm=ppo_score_norm,
+            ppo_whiten_rewards=ppo_whiten_rewards,
+        )
     )
 
     with gr.Accordion(open=False) as galore_tab:
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 7afe6ec3..bd4a4205 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -774,52 +774,52 @@
             "label": "RLHF 参数设置",
         },
     },
-    "dpo_beta": {
+    "pref_beta": {
         "en": {
-            "label": "DPO beta",
-            "info": "Value of the beta parameter in the DPO loss.",
+            "label": "Beta value",
+            "info": "Value of the beta parameter in the loss.",
         },
         "ru": {
-            "label": "DPO бета",
-            "info": "Значение параметра бета в функции потерь DPO.",
+            "label": "Бета значение",
+            "info": "Значение параметра бета в функции потерь.",
         },
         "zh": {
-            "label": "DPO beta 参数",
-            "info": "DPO 损失函数中 beta 超参数大小。",
+            "label": "Beta 参数",
+            "info": "损失函数中 beta 超参数大小。",
         },
     },
-    "dpo_ftx": {
+    "pref_ftx": {
         "en": {
-            "label": "DPO-ftx weight",
-            "info": "The weight of SFT loss in the DPO-ftx.",
+            "label": "Ftx gamma",
+            "info": "The weight of SFT loss in the final loss.",
         },
         "ru": {
-            "label": "Вес DPO-ftx",
-            "info": "Вес функции потерь SFT в DPO-ftx.",
+            "label": "Ftx гамма",
+            "info": "Вес потери SFT в итоговой потере.",
         },
         "zh": {
-            "label": "DPO-ftx 权重",
-            "info": "DPO-ftx 中 SFT 损失的权重大小。",
+            "label": "Ftx gamma",
+            "info": "损失函数中 SFT 损失的权重大小。",
         },
     },
-    "orpo_beta": {
+    "pref_loss": {
         "en": {
-            "label": "ORPO beta",
-            "info": "Value of the beta parameter in the ORPO loss.",
+            "label": "Loss type",
+            "info": "The type of the loss function.",
         },
         "ru": {
-            "label": "ORPO бета",
-            "info": "Значение параметра бета в функции потерь ORPO.",
+            "label": "Тип потерь",
+            "info": "Тип функции потерь.",
        },
         "zh": {
-            "label": "ORPO beta 参数",
-            "info": "ORPO 损失函数中 beta 超参数大小。",
+            "label": "损失类型",
+            "info": "损失函数的类型。",
         },
     },
     "reward_model": {
         "en": {
             "label": "Reward model",
-            "info": "Adapter of the reward model for PPO training.",
+            "info": "Adapter of the reward model in PPO training.",
         },
         "ru": {
             "label": "Модель вознаграждения",
@@ -830,6 +830,34 @@ LOCALES = {
             "info": "PPO 训练中奖励模型的适配器路径。",
         },
     },
+    "ppo_score_norm": {
+        "en": {
+            "label": "Score norm",
+            "info": "Normalizing scores in PPO training.",
+        },
+        "ru": {
+            "label": "Норма оценок",
+            "info": "Нормализация оценок в тренировке PPO.",
+        },
+        "zh": {
+            "label": "归一化分数",
+            "info": "PPO 训练中归一化奖励分数。",
+        },
+    },
+    "ppo_whiten_rewards": {
+        "en": {
+            "label": "Whiten rewards",
+            "info": "Whiten the rewards in PPO training.",
+        },
+        "ru": {
+            "label": "Белые вознаграждения",
+            "info": "Осветлите вознаграждения в обучении PPO.",
+        },
+        "zh": {
+            "label": "白化奖励",
+            "info": "PPO 训练中将奖励分数做白化处理。",
+        },
+    },
     "galore_tab": {
         "en": {
             "label": "GaLore configurations",
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index ef911a16..24046e62 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -145,11 +145,14 @@ class Runner:
             plot_loss=True,
         )
 
+        # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
             args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
             args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
-        elif args["finetuning_type"] == "lora":
+
+        # lora config
+        if args["finetuning_type"] == "lora":
             args["lora_rank"] = get("train.lora_rank")
             args["lora_alpha"] = get("train.lora_alpha")
             args["lora_dropout"] = get("train.lora_dropout")
@@ -163,6 +166,7 @@ class Runner:
         if args["use_llama_pro"]:
             args["num_layer_trainable"] = get("train.num_layer_trainable")
 
+        # rlhf config
         if args["stage"] == "ppo":
             args["reward_model"] = ",".join(
                 [
@@ -171,31 +175,41 @@ class Runner:
                 ]
             )
             args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
+            args["ppo_score_norm"] = get("train.ppo_score_norm")
+            args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
+            args["top_k"] = 0
+            args["top_p"] = 0.9
         elif args["stage"] == "dpo":
-            args["dpo_beta"] = get("train.dpo_beta")
-            args["dpo_ftx"] = get("train.dpo_ftx")
+            args["dpo_beta"] = get("train.pref_beta")
+            args["dpo_ftx"] = get("train.pref_ftx")
+            args["dpo_loss"] = get("train.pref_loss")
+        elif args["stage"] == "kto":
+            args["kto_beta"] = get("train.pref_beta")
+            args["kto_ftx"] = get("train.pref_ftx")
         elif args["stage"] == "orpo":
-            args["orpo_beta"] = get("train.orpo_beta")
-
-        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
-            args["val_size"] = get("train.val_size")
-            args["evaluation_strategy"] = "steps"
-            args["eval_steps"] = args["save_steps"]
-            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
-            args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
+            args["orpo_beta"] = get("train.pref_beta")
 
+        # galore config
         if args["use_galore"]:
             args["galore_rank"] = get("train.galore_rank")
             args["galore_update_interval"] = get("train.galore_update_interval")
             args["galore_scale"] = get("train.galore_scale")
             args["galore_target"] = get("train.galore_target")
 
+        # badam config
         if args["use_badam"]:
             args["badam_mode"] = get("train.badam_mode")
             args["badam_switch_mode"] = get("train.badam_switch_mode")
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")
 
+        # eval config
+        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+            args["val_size"] = get("train.val_size")
+            args["evaluation_strategy"] = "steps"
+            args["eval_steps"] = args["save_steps"]
+            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
+
         return args
 
     def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
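Note: across the preprocessing and trainer hunks above, the PaliGemma path reduces to three steps: read `image_seq_length` from the processor, prepend that many image placeholder tokens to `input_ids`, and build `token_type_ids` that separate the image prefix from the text (the trainers then simply forward `pixel_values` and `token_type_ids` whenever they appear in the batch). The sketch below restates that input layout outside the diff; `build_paligemma_inputs` is a hypothetical helper rather than part of the codebase, and marking image positions with 0 and text positions with 1 is an assumption about the convention `get_paligemma_token_type_ids` follows, not a copy of it.

```python
from typing import Dict, List


def build_paligemma_inputs(
    text_ids: List[int], image_token_id: int, image_seq_length: int
) -> Dict[str, List[int]]:
    """Prefix text ids with image placeholders and mark the two segments.

    Assumption: token_type_ids uses 0 for image positions and 1 for text positions.
    """
    input_ids = [image_token_id] * image_seq_length + text_ids
    attention_mask = [1] * len(input_ids)
    token_type_ids = [0] * image_seq_length + [1] * len(text_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}


# toy usage with made-up ids: a 4-position image prefix followed by 3 text tokens
example = build_paligemma_inputs(text_ids=[17, 52, 99], image_token_id=7, image_seq_length=4)
assert example["input_ids"] == [7, 7, 7, 7, 17, 52, 99]
assert example["token_type_ids"] == [0, 0, 0, 0, 1, 1, 1]
```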