Fix PaliGemma data preprocessing
This commit is contained in:
parent
542229abb3
commit
e55c85ac72
|
@ -2,7 +2,7 @@ from functools import partial
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ..extras.constants import IGNORE_INDEX
|
from ..extras.constants import IGNORE_INDEX, IMAGE_TOKEN
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.packages import is_pillow_available
|
from ..extras.packages import is_pillow_available
|
||||||
from .utils import Role
|
from .utils import Role
|
||||||
|
@ -80,11 +80,17 @@ def preprocess_supervised_dataset(
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if processor is not None:
|
if processor is not None and not hasattr(processor, "image_seq_length"): # llava case
|
||||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||||
|
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
|
|
||||||
|
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case
|
||||||
|
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||||
|
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||||
|
labels += [image_token_id] * getattr(processor, "image_seq_length")
|
||||||
|
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||||
template.encode_multiturn(
|
template.encode_multiturn(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -183,8 +189,8 @@ def preprocess_unsupervised_dataset(
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if processor is not None:
|
if processor is not None and not hasattr(processor, "image_seq_length"): # llava case
|
||||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||||
|
|
||||||
if len(examples["response"][i]) == 1:
|
if len(examples["response"][i]) == 1:
|
||||||
messages = examples["prompt"][i] + examples["response"][i]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
|
@ -203,6 +209,10 @@ def preprocess_unsupervised_dataset(
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
labels += [tokenizer.eos_token_id]
|
labels += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case
|
||||||
|
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||||
|
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||||
|
|
||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
|
@ -230,8 +240,8 @@ def preprocess_pairwise_dataset(
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if processor is not None:
|
if processor is not None and not hasattr(processor, "image_seq_length"): # llava case
|
||||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||||
|
|
||||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||||
|
@ -256,6 +266,10 @@ def preprocess_pairwise_dataset(
|
||||||
chosen_ids += [tokenizer.eos_token_id]
|
chosen_ids += [tokenizer.eos_token_id]
|
||||||
rejected_ids += [tokenizer.eos_token_id]
|
rejected_ids += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case
|
||||||
|
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||||
|
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||||
|
|
||||||
model_inputs["prompt_ids"].append(prompt_ids)
|
model_inputs["prompt_ids"].append(prompt_ids)
|
||||||
model_inputs["chosen_ids"].append(chosen_ids)
|
model_inputs["chosen_ids"].append(chosen_ids)
|
||||||
model_inputs["rejected_ids"].append(rejected_ids)
|
model_inputs["rejected_ids"].append(rejected_ids)
|
||||||
|
@ -292,8 +306,8 @@ def preprocess_kto_dataset(
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if processor is not None:
|
if processor is not None and not hasattr(processor, "image_seq_length"): # llava case
|
||||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
|
||||||
|
|
||||||
if examples["response"][i][0]["content"]: # desired example
|
if examples["response"][i][0]["content"]: # desired example
|
||||||
kto_tag = True
|
kto_tag = True
|
||||||
|
@ -328,6 +342,10 @@ def preprocess_kto_dataset(
|
||||||
response_ids += [tokenizer.eos_token_id]
|
response_ids += [tokenizer.eos_token_id]
|
||||||
kl_response_ids += [tokenizer.eos_token_id]
|
kl_response_ids += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case
|
||||||
|
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||||
|
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||||
|
|
||||||
input_ids = prompt_ids + response_ids
|
input_ids = prompt_ids + response_ids
|
||||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||||
kl_input_ids = prompt_ids + kl_response_ids
|
kl_input_ids = prompt_ids + kl_response_ids
|
||||||
|
|
Loading…
Reference in New Issue