This commit is contained in:
BUAADreamer 2024-04-25 20:09:43 +08:00
parent 68cdd9a020
commit eefcd105c1
1 changed files with 0 additions and 42 deletions

View File

@ -1,42 +0,0 @@
from dataclasses import dataclass
from transformers import AutoProcessor
@dataclass
class DataCollatorForVis2Seq:
processor: AutoProcessor
def __call__(self, examples):
texts = []
images = []
for example in examples:
if len(example["images"]) > 1:
raise ValueError("This collator only supports one image per example")
messages = example["messages"]
text = self.processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
texts.append(text)
images.append(example["images"][0])
batch = self.processor(
text=texts, images=images, return_tensors="pt", padding=True
)
labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
@dataclass
class DataCollatorForMLLM:
processor: AutoProcessor
def __call__(self, examples):
print(examples[0].keys())
print(examples[0]["input_ids"])
batch = {}
return batch