fix paligemma sft

requires transformers>=4.41.1
This commit is contained in:
hiyouga 2024-05-24 00:23:40 +08:00
parent 67ebc7b388
commit de0e67aff1
1 changed files with 12 additions and 4 deletions

View File

@ -74,19 +74,21 @@ def preprocess_supervised_dataset(
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava case
if processor is not None and not hasattr(processor, "image_seq_length"): # llava models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
@ -120,6 +122,10 @@ def preprocess_supervised_dataset(
model_inputs["labels"].append(labels)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
if hasattr(processor, "image_seq_length"): # paligemma models
token_type_ids = [0] * getattr(processor, "image_seq_length")
token_type_ids += [1] * (len(input_ids) - getattr(processor, "image_seq_length"))
model_inputs["token_type_ids"].append(token_type_ids)
return model_inputs
@ -183,13 +189,15 @@ def preprocess_unsupervised_dataset(
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava case
if processor is not None and not hasattr(processor, "image_seq_length"): # llava models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
@ -209,7 +217,7 @@ def preprocess_unsupervised_dataset(
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma case
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids