From e55c85ac72f4938738dbce576f83b47a1fea88ae Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 20 May 2024 23:51:32 +0800 Subject: [PATCH] fix paligemma data preprocess --- src/llamafactory/data/preprocess.py | 36 +++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 557678e6..4bc5ad3c 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -2,7 +2,7 @@ from functools import partial from itertools import chain from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple -from ..extras.constants import IGNORE_INDEX +from ..extras.constants import IGNORE_INDEX, IMAGE_TOKEN from ..extras.logging import get_logger from ..extras.packages import is_pillow_available from .utils import Role @@ -80,11 +80,17 @@ def preprocess_supervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] + if processor is not None and not hasattr(processor, "image_seq_length"): # llava case + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = [], [] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + input_ids += [image_token_id] * getattr(processor, "image_seq_length") + labels += [image_token_id] * getattr(processor, "image_seq_length") + for turn_idx, (source_ids, target_ids) in enumerate( template.encode_multiturn( tokenizer, @@ -183,8 +189,8 @@ def preprocess_unsupervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] + if processor is not None and not hasattr(processor, "image_seq_length"): # llava case + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] if len(examples["response"][i]) == 1: messages = examples["prompt"][i] + examples["response"][i] @@ -203,6 +209,10 @@ def preprocess_unsupervised_dataset( if template.efficient_eos: labels += [tokenizer.eos_token_id] + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids + model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) @@ -230,8 +240,8 @@ def preprocess_pairwise_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] + if processor is not None and not hasattr(processor, "image_seq_length"): # llava case + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] @@ -256,6 +266,10 @@ def preprocess_pairwise_dataset( chosen_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id] + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + model_inputs["prompt_ids"].append(prompt_ids) model_inputs["chosen_ids"].append(chosen_ids) model_inputs["rejected_ids"].append(rejected_ids) @@ -292,8 +306,8 @@ def preprocess_kto_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None: - examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] + if processor is not None and not hasattr(processor, "image_seq_length"): # llava case + examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"] if examples["response"][i][0]["content"]: # desired example kto_tag = True @@ -328,6 +342,10 @@ def preprocess_kto_dataset( response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id] + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case + image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + input_ids = prompt_ids + response_ids labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids kl_input_ids = prompt_ids + kl_response_ids