diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index f59029a1..5f0d02a7 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -61,7 +61,7 @@ class HuggingfaceEngine(BaseEngine): and image is not None and not hasattr(processor, "image_seq_length") and IMAGE_TOKEN not in messages[0]["content"] - ): # llava case + ): # llava-like models messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] @@ -74,7 +74,7 @@ class HuggingfaceEngine(BaseEngine): image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") batch_feature = image_processor(image, return_tensors="pt") pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W) - if hasattr(processor, "image_seq_length"): # paligemma case + if hasattr(processor, "image_seq_length"): # paligemma models image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 31d03fbe..e424481f 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -98,7 +98,7 @@ class VllmEngine(BaseEngine): and image is not None and not hasattr(self.processor, "image_seq_length") and IMAGE_TOKEN not in messages[0]["content"] - ): # llava case + ): # llava-like models messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 474d6a30..1dc8dd8d 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Dict, Sequence import torch from transformers import DataCollatorForSeq2Seq @@ -11,21 +11,6 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): Data collator for pairwise data. """ - def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: - r""" - Masks out the input ids except for the responses. - """ - padded_labels = [] - for feature, (prompt_len, answer_len) in zip(batch, positions): - if self.tokenizer.padding_side == "left": - start, end = feature.size(0) - answer_len, feature.size(0) - else: - start, end = prompt_len, prompt_len + answer_len - padded_tensor = self.label_pad_token_id * torch.ones_like(feature) - padded_tensor[start:end] = feature[start:end] - padded_labels.append(padded_tensor) - return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: r""" Pads batched data to the longest sequence in the batch. @@ -34,21 +19,22 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): the last n examples represent rejected examples. 
""" concatenated_features = [] - label_positions = [] - for key in ("chosen_ids", "rejected_ids"): + for key in ("chosen", "rejected"): for feature in features: - prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) - concatenated_features.append( - { - "input_ids": feature["prompt_ids"] + feature[key], - "attention_mask": [1] * (prompt_len + answer_len), - } - ) - label_positions.append((prompt_len, answer_len)) + target_feature = { + "input_ids": feature["{}_input_ids".format(key)], + "attention_mask": feature["{}_attention_mask".format(key)], + "labels": feature["{}_labels".format(key)], + } + if "pixel_values" in feature: + target_feature["pixel_values"] = feature["pixel_values"] - batch = super().__call__(concatenated_features) - batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) - return batch + if "{}_token_type_ids".format(key) in feature: + target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)] + + concatenated_features.append(target_feature) + + return super().__call__(concatenated_features) @dataclass @@ -62,20 +48,25 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): kl_features = [] kto_tags = [] for feature in features: - target_features.append( - { - "input_ids": feature["input_ids"], - "attention_mask": feature["attention_mask"], - "labels": feature["labels"], - } - ) - kl_features.append( - { - "input_ids": feature["kl_input_ids"], - "attention_mask": feature["kl_attention_mask"], - "labels": feature["kl_labels"], - } - ) + target_feature = { + "input_ids": feature["input_ids"], + "attention_mask": feature["attention_mask"], + "labels": feature["labels"], + } + kl_feature = { + "input_ids": feature["kl_input_ids"], + "attention_mask": feature["kl_attention_mask"], + "labels": feature["kl_labels"], + } + if "pixel_values" in feature: + target_feature["pixel_values"] = feature["pixel_values"] + + if "token_type_ids" in feature: + target_feature["token_type_ids"] = feature["token_type_ids"] + kl_feature["token_type_ids"] = feature["kl_token_type_ids"] + + target_features.append(target_feature) + kl_features.append(kl_feature) kto_tags.append(feature["kto_tags"]) batch = super().__call__(target_features) @@ -83,5 +74,8 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): batch["kl_input_ids"] = kl_batch["input_ids"] batch["kl_attention_mask"] = kl_batch["attention_mask"] batch["kl_labels"] = kl_batch["labels"] + if "token_type_ids" in batch: + batch["kl_token_type_ids"] = kl_batch["token_type_ids"] + batch["kto_tags"] = torch.tensor(kto_tags) return batch diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index bed694a2..48d28f1d 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -1,5 +1,6 @@ import inspect import os +import sys from typing import TYPE_CHECKING, Literal, Optional, Union from datasets import load_dataset, load_from_disk @@ -167,12 +168,15 @@ def get_dataset( logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) - exit(0) + sys.exit(0) if training_args.should_log: try: print_function(next(iter(dataset))) except StopIteration: - raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") + if stage == "pt": + raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") + else: + raise RuntimeError("Cannot find valid samples, check 
`data/README.md` for the data format.") return dataset diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 4123645f..336257ca 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -1,406 +1,25 @@ from functools import partial -from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple -from ..extras.constants import IGNORE_INDEX, IMAGE_TOKEN -from ..extras.logging import get_logger -from ..extras.packages import is_pillow_available -from .utils import Role - - -if is_pillow_available(): - from PIL import Image +from .processors.feedback import preprocess_feedback_dataset +from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example +from .processors.pretrain import preprocess_pretrain_dataset +from .processors.supervised import ( + preprocess_packed_supervised_dataset, + preprocess_supervised_dataset, + print_supervised_dataset_example, +) +from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example if TYPE_CHECKING: - from numpy.typing import NDArray - from PIL.Image import Image as ImageObject from transformers import ProcessorMixin, Seq2SeqTrainingArguments - from transformers.image_processing_utils import BaseImageProcessor from transformers.tokenization_utils import PreTrainedTokenizer from ..hparams import DataArguments from .template import Template -logger = get_logger(__name__) - - -def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": - # process visual inputs (currently only supports a single image) - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) - return image_processor(image, return_tensors="pt")["pixel_values"][0] - - -def preprocess_pretrain_dataset( - examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" -) -> Dict[str, List[List[int]]]: - # build grouped texts with format `X1 X2 X3 ...` if packing is enabled - text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] - - if not data_args.packing: - if data_args.template == "gemma": - text_examples = [tokenizer.bos_token + example for example in text_examples] - - result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) - else: - tokenized_examples = tokenizer(text_examples, add_special_tokens=False) - concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} - total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) - block_size = data_args.cutoff_len - total_length = (total_length // block_size) * block_size - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - if data_args.template == "gemma": - for i in range(len(result["input_ids"])): - result["input_ids"][i][0] = tokenizer.bos_token_id - - return result - - -def preprocess_supervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build inputs with format ` X Y ` and labels with format ` 
... Y ` - # for multiturn examples, we only mask the prompt part in each prompt-response pair. - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"] = [] - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None and not hasattr(processor, "image_seq_length"): # llava models - examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] - - messages = examples["prompt"][i] + examples["response"][i] - input_ids, labels = [], [] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - input_ids += [image_token_id] * getattr(processor, "image_seq_length") - labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - - for turn_idx, (source_ids, target_ids) in enumerate( - template.encode_multiturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - if hasattr(processor, "image_seq_length"): # paligemma models - token_type_ids = [0] * getattr(processor, "image_seq_length") - token_type_ids += [1] * (len(input_ids) - getattr(processor, "image_seq_length")) - model_inputs["token_type_ids"].append(token_type_ids) - - return model_inputs - - -def preprocess_packed_supervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build inputs with format ` X1 Y1 X2 Y2 ` - # and labels with format ` ... Y1 ... 
Y2 ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - input_ids, labels = [], [] - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - messages = examples["prompt"][i] + examples["response"][i] - for source_ids, target_ids in template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tools"][i] - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif len(input_ids) != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - - total_length = len(input_ids) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - for i in range(0, total_length, block_size): - if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): - model_inputs["input_ids"].append(input_ids[i : i + block_size]) - model_inputs["attention_mask"].append([1] * block_size) - model_inputs["labels"].append(labels[i : i + block_size]) - - return model_inputs - - -def preprocess_unsupervised_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build inputs with format ` X` and labels with format `Y ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"] = [] - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None and not hasattr(processor, "image_seq_length"): # llava models - examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] - - if len(examples["response"][i]) == 1: - messages = examples["prompt"][i] + examples["response"][i] - else: - messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] - - input_ids, labels = template.encode_oneturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - labels += [tokenizer.eos_token_id] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids - - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - return model_inputs - - 
-def preprocess_pairwise_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # build input pairs with format ` X`, `Y1 ` and `Y2 ` - model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None and not hasattr(processor, "image_seq_length"): # llava case - examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] - - chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] - rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] - prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, - chosen_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - _, rejected_ids = template.encode_oneturn( - tokenizer, - rejected_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - chosen_ids += [tokenizer.eos_token_id] - rejected_ids += [tokenizer.eos_token_id] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids - - model_inputs["prompt_ids"].append(prompt_ids) - model_inputs["chosen_ids"].append(chosen_ids) - model_inputs["rejected_ids"].append(rejected_ids) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - return model_inputs - - -def preprocess_kto_dataset( - examples: Dict[str, List[Any]], - template: "Template", - tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], - data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: - # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs - kl_response = examples["response"][::-1] - model_inputs = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "kl_input_ids": [], - "kl_attention_mask": [], - "kl_labels": [], - "kto_tags": [], - } - if processor is not None: - model_inputs["pixel_values"] = [] - preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) - - for i in range(len(examples["prompt"])): - if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) - continue - - if processor is not None and not hasattr(processor, "image_seq_length"): # llava case - examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] - - if examples["response"][i][0]["content"]: # desired example - kto_tag = True - messages = examples["prompt"][i] + [examples["response"][i][0]] - else: # undesired example - kto_tag = False - messages = examples["prompt"][i] + [examples["response"][i][1]] - - if kl_response[i][0]["content"]: - kl_messages = examples["prompt"][i] + 
[kl_response[i][0]] - else: - kl_messages = examples["prompt"][i] + [kl_response[i][1]] - - prompt_ids, response_ids = template.encode_oneturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - _, kl_response_ids = template.encode_oneturn( - tokenizer, - kl_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - response_ids += [tokenizer.eos_token_id] - kl_response_ids += [tokenizer.eos_token_id] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids - - input_ids = prompt_ids + response_ids - labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids - kl_input_ids = prompt_ids + kl_response_ids - kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids - model_inputs["input_ids"].append(input_ids) - model_inputs["attention_mask"].append([1] * len(input_ids)) - model_inputs["labels"].append(labels) - model_inputs["kl_input_ids"].append(kl_input_ids) - model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) - model_inputs["kl_labels"].append(kl_labels) - model_inputs["kto_tags"].append(kto_tag) - if processor is not None: - model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) - - desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) - undesirable_num = len(model_inputs["kto_tags"]) - desirable_num - if desirable_num == 0 or undesirable_num == 0: - logger.warning("Your dataset only has one preference type.") - - return model_inputs - - -def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - print("label_ids:\n{}".format(example["labels"])) - print( - "labels:\n{}".format( - tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) - ) - ) - - -def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("prompt_ids:\n{}".format(example["prompt_ids"])) - print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) - print("chosen_ids:\n{}".format(example["chosen_ids"])) - print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) - print("rejected_ids:\n{}".format(example["rejected_ids"])) - print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) - - -def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: - print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) - - def get_preprocess_and_print_func( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", @@ -445,7 +64,7 @@ def get_preprocess_and_print_func( print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) elif stage == "kto": preprocess_func = partial( - preprocess_kto_dataset, + preprocess_feedback_dataset, template=template, tokenizer=tokenizer, processor=processor, diff --git 
a/src/llamafactory/data/processors/__init__.py b/src/llamafactory/data/processors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py new file mode 100644 index 00000000..51db3e26 --- /dev/null +++ b/src/llamafactory/data/processors/feedback.py @@ -0,0 +1,110 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN +from ...extras.logging import get_logger +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_feedback_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs + kl_response = examples["response"][::-1] + model_inputs = { + "input_ids": [], + "attention_mask": [], + "labels": [], + "kl_input_ids": [], + "kl_attention_mask": [], + "kl_labels": [], + "kto_tags": [], + } + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + model_inputs["kl_token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + if examples["response"][i][0]["content"]: # desired example + kto_tag = True + messages = examples["prompt"][i] + [examples["response"][i][0]] + else: # undesired example + kto_tag = False + messages = examples["prompt"][i] + [examples["response"][i][1]] + + if kl_response[i][0]["content"]: + kl_messages = examples["prompt"][i] + [kl_response[i][0]] + else: + kl_messages = examples["prompt"][i] + [kl_response[i][1]] + + prompt_ids, response_ids = template.encode_oneturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + _, kl_response_ids = template.encode_oneturn( + tokenizer, + kl_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + + if template.efficient_eos: + response_ids += [tokenizer.eos_token_id] + kl_response_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + input_ids = prompt_ids + response_ids + labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + kl_input_ids = prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + 
model_inputs["kl_input_ids"].append(kl_input_ids) + model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) + model_inputs["kl_labels"].append(kl_labels) + model_inputs["kto_tags"].append(kto_tag) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor)) + + desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) + undesirable_num = len(model_inputs["kto_tags"]) - desirable_num + if desirable_num == 0 or undesirable_num == 0: + logger.warning("Your dataset only has one preference type.") + + return model_inputs diff --git a/src/llamafactory/data/processors/mm_utils.py b/src/llamafactory/data/processors/mm_utils.py new file mode 100644 index 00000000..abc7c4b2 --- /dev/null +++ b/src/llamafactory/data/processors/mm_utils.py @@ -0,0 +1,27 @@ +from typing import TYPE_CHECKING, List, Sequence + +from ...extras.packages import is_pillow_available + + +if is_pillow_available(): + from PIL import Image + + +if TYPE_CHECKING: + from numpy.typing import NDArray + from PIL.Image import Image as ImageObject + from transformers import ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + + +def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": + # process visual inputs (currently only supports a single image) + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) + return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W) + + +def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]: + # get paligemma token type ids for computing loss + image_seq_length = getattr(processor, "image_seq_length") + return [0] * image_seq_length + [1] * (input_len - image_seq_length) diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py new file mode 100644 index 00000000..ec0fb96e --- /dev/null +++ b/src/llamafactory/data/processors/pairwise.py @@ -0,0 +1,109 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN +from ...extras.logging import get_logger +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = { + "chosen_input_ids": [], + "chosen_attention_mask": [], + "chosen_labels": [], + "rejected_input_ids": [], + "rejected_attention_mask": [], + "rejected_labels": [], + } + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["chosen_token_type_ids"] = [] + 
model_inputs["rejected_token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] + rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] + prompt_ids, chosen_ids = template.encode_oneturn( + tokenizer, + chosen_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + _, rejected_ids = template.encode_oneturn( + tokenizer, + rejected_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + + if template.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids + model_inputs["chosen_input_ids"].append(chosen_input_ids) + model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["chosen_token_type_ids"].append( + get_paligemma_token_type_ids(len(chosen_input_ids), processor) + ) + model_inputs["rejected_token_type_ids"].append( + get_paligemma_token_type_ids(len(rejected_input_ids), processor) + ) + + return model_inputs + + +def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) + valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) + print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) + print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) + print("chosen_label_ids:\n{}".format(example["chosen_labels"])) + print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False))) + print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) + print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) + print("rejected_label_ids:\n{}".format(example["rejected_labels"])) + print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py new file 
index 00000000..3de0d1ac
--- /dev/null
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -0,0 +1,36 @@
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, List
+
+
+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
+    from ...hparams import DataArguments
+
+
+def preprocess_pretrain_dataset(
+    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
+) -> Dict[str, List[List[int]]]:
+    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
+    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+
+    if not data_args.packing:
+        if data_args.template == "gemma":
+            text_examples = [tokenizer.bos_token + example for example in text_examples]
+
+        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+    else:
+        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
+        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+        block_size = data_args.cutoff_len
+        total_length = (total_length // block_size) * block_size
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        if data_args.template == "gemma":
+            for i in range(len(result["input_ids"])):
+                result["input_ids"][i][0] = tokenizer.bos_token_id
+
+    return result
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
new file mode 100644
index 00000000..80326d98
--- /dev/null
+++ b/src/llamafactory/data/processors/supervised.py
@@ -0,0 +1,137 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
+from ...extras.logging import get_logger
+from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+
+
+if TYPE_CHECKING:
+    from transformers import ProcessorMixin
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
+    from ...hparams import DataArguments
+    from ..template import Template
+
+
+logger = get_logger(__name__)
+
+
+def preprocess_supervised_dataset(
+    examples: Dict[str, List[Any]],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + messages = examples["prompt"][i] + examples["response"][i] + input_ids, labels = [], [] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + input_ids += [image_token_id] * getattr(processor, "image_seq_length") + labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") + + for turn_idx, (source_ids, target_ids) in enumerate( + template.encode_multiturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + + return model_inputs + + +def preprocess_packed_supervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... 
Y2 ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + input_ids, labels = [], [] + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + messages = examples["prompt"][i] + examples["response"][i] + for source_ids, target_ids in template.encode_multiturn( + tokenizer, messages, examples["system"][i], examples["tools"][i] + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif len(input_ids) != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + total_length = len(input_ids) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + for i in range(0, total_length, block_size): + if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): + model_inputs["input_ids"].append(input_ids[i : i + block_size]) + model_inputs["attention_mask"].append([1] * block_size) + model_inputs["labels"].append(labels[i : i + block_size]) + + return model_inputs + + +def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py new file mode 100644 index 00000000..4adf4f61 --- /dev/null +++ b/src/llamafactory/data/processors/unsupervised.py @@ -0,0 +1,76 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from ...extras.constants import IMAGE_TOKEN +from ...extras.logging import get_logger +from ..utils import Role +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.tokenization_utils import PreTrainedTokenizer + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def preprocess_unsupervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X` and labels with format `Y ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + if processor is not None 
and not hasattr(processor, "image_seq_length"): # llava-like models + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] + + if len(examples["response"][i]) == 1: + messages = examples["prompt"][i] + examples["response"][i] + else: + messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] + + input_ids, labels = template.encode_oneturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + + if template.efficient_eos: + labels += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + + return model_inputs + + +def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 519e95f1..23aa2c8a 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -4,7 +4,7 @@ from types import MethodType from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union import torch -from transformers import BatchEncoding, Trainer +from transformers import Trainer from trl import DPOTrainer from trl.trainer.utils import disable_dropout_in_model @@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer): Otherwise the average log probabilities. 
""" - batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error - - all_logits: "torch.Tensor" = model( - input_ids=batch_copied["input_ids"], - attention_mask=batch_copied["attention_mask"], - return_dict=True, - use_cache=False, - ).logits.to(torch.float32) + batch_copied = {k: v.detach().clone() for k, v in batch.items()} # avoid error + all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32) all_logps = self.get_batch_logps( logits=all_logits, diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 5578c50c..b0e42406 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer): self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: with torch.no_grad(): - kl_logits = model( - input_ids=batch["kl_input_ids"], - attention_mask=batch["kl_attention_mask"], - return_dict=True, - use_cache=False, - ).logits.to(torch.float32) + kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]} + if "pixel_values" in batch: + kl_model_inputs["pixel_values"] = batch["pixel_values"] - target_logits = model( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - return_dict=True, - use_cache=False, - ).logits.to(torch.float32) + if "kl_token_type_ids" in batch: + kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"] + + kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) + + model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]} + if "pixel_values" in batch: + model_inputs["pixel_values"] = batch["pixel_values"] + + if "token_type_ids" in batch: + model_inputs["token_type_ids"] = batch["token_type_ids"] + + target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) target_logps = self.get_batch_logps( logits=target_logits, diff --git a/src/llamafactory/train/orpo/trainer.py b/src/llamafactory/train/orpo/trainer.py index 1b743647..7cfdb429 100644 --- a/src/llamafactory/train/orpo/trainer.py +++ b/src/llamafactory/train/orpo/trainer.py @@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer): r""" Computes the average log probabilities of the labels under the given logits. """ - all_logits: "torch.Tensor" = model( - input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False - ).logits.to(torch.float32) + all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logps = self.get_batch_logps( logits=all_logits,