clean code

This commit is contained in:
hiyouga 2024-06-13 01:58:16 +08:00
parent 1f23f25226
commit 2ed8270112
4 changed files with 17 additions and 27 deletions

View File

@ -1,12 +1,10 @@
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from packaging import version
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_device_count from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available, _get_package_version from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
from ..model import load_config, load_tokenizer from ..model import load_config, load_tokenizer
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -16,7 +14,7 @@ if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
if _get_package_version("vllm") >= version.parse("0.5.0"): if is_vllm_version_greater_than_0_5():
from vllm.multimodal.image import ImagePixelData from vllm.multimodal.image import ImagePixelData
else: else:
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
@ -112,9 +110,9 @@ class VllmEngine(BaseEngine):
if self.processor is not None and image is not None: # add image features if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"] pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
if _get_package_version("vllm") >= version.parse("0.5.0"): if is_vllm_version_greater_than_0_5():
multi_modal_data = ImagePixelData(pixel_values) multi_modal_data = ImagePixelData(image=pixel_values)
else: else: # TODO: remove vllm 0.4.3 support
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else: else:
multi_modal_data = None multi_modal_data = None

View File

@ -1,5 +1,6 @@
import importlib.metadata import importlib.metadata
import importlib.util import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from packaging import version from packaging import version
@ -24,10 +25,6 @@ def is_fastapi_available():
return _is_package_available("fastapi") return _is_package_available("fastapi")
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available(): def is_galore_available():
return _is_package_available("galore_torch") return _is_package_available("galore_torch")
@ -36,18 +33,10 @@ def is_gradio_available():
return _is_package_available("gradio") return _is_package_available("gradio")
def is_jieba_available():
return _is_package_available("jieba")
def is_matplotlib_available(): def is_matplotlib_available():
return _is_package_available("matplotlib") return _is_package_available("matplotlib")
def is_nltk_available():
return _is_package_available("nltk")
def is_pillow_available(): def is_pillow_available():
return _is_package_available("PIL") return _is_package_available("PIL")
@ -60,10 +49,6 @@ def is_rouge_available():
return _is_package_available("rouge_chinese") return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available(): def is_starlette_available():
return _is_package_available("sse_starlette") return _is_package_available("sse_starlette")
@ -74,3 +59,8 @@ def is_uvicorn_available():
def is_vllm_available(): def is_vllm_available():
return _is_package_available("vllm") return _is_package_available("vllm")
@lru_cache
def is_vllm_version_greater_than_0_5():
return _get_package_version("vllm") >= version.parse("0.5.0")

View File

@ -1,7 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING: if TYPE_CHECKING:
@ -21,13 +22,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
requested_attn_implementation = "eager" requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
if not is_sdpa_available(): if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.") logger.warning("torch>=2.1.1 is required for SDPA attention.")
return return
requested_attn_implementation = "sdpa" requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2": elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available(): if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.") logger.warning("FlashAttention-2 is not installed.")
return return

View File

@ -2,9 +2,10 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import numpy as np import numpy as np
from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available from ...extras.packages import is_rouge_available
if TYPE_CHECKING: if TYPE_CHECKING: