fix paligemma inference

This commit is contained in:
hiyouga 2024-05-20 23:36:43 +08:00
parent 7262679666
commit 542229abb3
4 changed files with 47 additions and 20 deletions

View File

@ -8,6 +8,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine):
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" + messages[0]["content"]
if (
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava case
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma case
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
@ -122,10 +137,8 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
if processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
return gen_kwargs, prompt_length

View File

@ -2,6 +2,7 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
@ -17,7 +18,6 @@ if is_vllm_available():
if TYPE_CHECKING:
import torch
from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor
@ -67,7 +67,7 @@ class VllmEngine(BaseEngine):
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
@ -92,14 +92,28 @@ class VllmEngine(BaseEngine):
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava case
messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
@ -144,13 +158,6 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True,
)
if self.processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
result_generator = self.model.generate(
prompt=None,
sampling_params=sampling_params,

View File

@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set):
if "bos_token" in slot:
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@ -614,6 +616,9 @@ _register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
force_system=True,
)

View File

@ -22,6 +22,8 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100
IMAGE_TOKEN = "<image>"
LAYERNORM_NAMES = {"norm", "ln"}
METHODS = ["full", "freeze", "lora"]