diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index d2850a6e..34126adf 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -1,10 +1,12 @@
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
+from packaging import version
+
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available
+from ..extras.packages import is_vllm_available, _get_package_version
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
@@ -14,10 +16,10 @@ if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
 
-    try:
-        from vllm.multimodal import MultiModalData  # type: ignore (for vllm>=0.5.0)
-    except ImportError:
-        from vllm.sequence import MultiModalData  # for vllm<0.5.0
+    if _get_package_version("vllm") >= version.parse("0.5.0"):
+        from vllm.multimodal.image import ImagePixelData
+    else:
+        from vllm.sequence import MultiModalData
 
 
 if TYPE_CHECKING:
@@ -110,7 +112,10 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+            if _get_package_version("vllm") >= version.parse("0.5.0"):
+                multi_modal_data = ImagePixelData(pixel_values)
+            else:
+                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
             multi_modal_data = None
 
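
For reference, _get_package_version is imported from ..extras.packages above, but its definition is not part of this diff. Below is a minimal sketch of what such a helper presumably looks like, assuming it resolves the installed distribution's version via importlib.metadata and falls back to "0.0.0" when the lookup fails; the file location and the fallback behavior are assumptions, not taken from the diff:

    import importlib.metadata

    from packaging import version
    from packaging.version import Version


    def _get_package_version(name: str) -> "Version":
        # Hypothetical helper (assumed to live in src/llamafactory/extras/packages.py).
        # Resolve the installed version of a package; treat a missing or broken
        # install as the lowest possible version, so the vllm<0.5.0 fallback
        # branch is taken when vLLM cannot be queried.
        try:
            return version.parse(importlib.metadata.version(name))
        except Exception:
            return version.parse("0.0.0")

Under that assumption, the _get_package_version("vllm") >= version.parse("0.5.0") checks above are ordinary packaging.version comparisons following PEP 440 ordering, which is why the diff can branch on them both at import time and when building the per-request multi_modal_data.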