diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 91ac2949..ac7ac769 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
@@ -29,7 +29,9 @@ if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
 
-    if is_vllm_version_greater_than_0_5():
+    if is_vllm_version_greater_than_0_5_1():
+        pass
+    elif is_vllm_version_greater_than_0_5():
         from vllm.multimodal.image import ImagePixelData
     else:
         from vllm.sequence import MultiModalData
@@ -130,8 +132,10 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            if is_vllm_version_greater_than_0_5():
-                multi_modal_data = ImagePixelData(image=pixel_values)
+            if is_vllm_version_greater_than_0_5_1():
+                multi_modal_data = {"image": pixel_values}
+            elif is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
             else:  # TODO: remove vllm 0.4.3 support
                 multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 0a84a293..1ba4cfbb 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -81,3 +81,7 @@ def is_vllm_available():
 @lru_cache
 def is_vllm_version_greater_than_0_5():
     return _get_package_version("vllm") >= version.parse("0.5.0")
+
+@lru_cache
+def is_vllm_version_greater_than_0_5_1():
+    return _get_package_version("vllm") >= version.parse("0.5.1")
\ No newline at end of file
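
For reference, a minimal sketch (not part of the patch) of the version gating this diff introduces, assuming vLLM and `packaging` are installed and `pixel_values` has already been produced by the image processor; the helper name `build_multi_modal_data` is illustrative only:

```python
from importlib.metadata import version as get_installed_version

from packaging import version


def build_multi_modal_data(pixel_values):
    # Select the multimodal payload format expected by the installed vLLM,
    # mirroring the branches added to VllmEngine in this diff.
    vllm_version = version.parse(get_installed_version("vllm"))
    if vllm_version >= version.parse("0.5.1"):
        # vLLM >= 0.5.1 accepts a plain dict keyed by modality.
        return {"image": pixel_values}
    elif vllm_version >= version.parse("0.5.0"):
        # vLLM 0.5.0 wraps pixel values in ImagePixelData.
        from vllm.multimodal.image import ImagePixelData

        return ImagePixelData(image=pixel_values)
    else:
        # Older releases (0.4.3) use MultiModalData; the TODO above marks this for removal.
        from vllm.sequence import MultiModalData

        return MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
```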