rm some
This commit is contained in:
parent
68cdd9a020
commit
eefcd105c1
|
@ -1,42 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForVis2Seq:
|
||||
processor: AutoProcessor
|
||||
|
||||
def __call__(self, examples):
|
||||
texts = []
|
||||
images = []
|
||||
for example in examples:
|
||||
if len(example["images"]) > 1:
|
||||
raise ValueError("This collator only supports one image per example")
|
||||
messages = example["messages"]
|
||||
text = self.processor.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
texts.append(text)
|
||||
images.append(example["images"][0])
|
||||
|
||||
batch = self.processor(
|
||||
text=texts, images=images, return_tensors="pt", padding=True
|
||||
)
|
||||
|
||||
labels = batch["input_ids"].clone()
|
||||
if self.processor.tokenizer.pad_token_id is not None:
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForMLLM:
|
||||
processor: AutoProcessor
|
||||
|
||||
def __call__(self, examples):
|
||||
print(examples[0].keys())
|
||||
print(examples[0]["input_ids"])
|
||||
batch = {}
|
||||
return batch
|
Loading…
Reference in New Issue