From 946f60113630d659e7048bffbb3aa7132ac3ecd1 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 6 Jun 2024 02:29:55 +0800
Subject: [PATCH] support image input in api #3971 #4061

---
 README.md                        |  3 +++
 README_zh.md                     |  3 +++
 src/llamafactory/api/chat.py     | 39 ++++++++++++++++++++++++++------
 src/llamafactory/api/protocol.py | 12 +++++++++-
 4 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 5e8bc8eb..3eebf355 100644
--- a/README.md
+++ b/README.md
@@ -456,6 +456,9 @@ docker compose -f ./docker-compose.yml up -d
 CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```

+> [!TIP]
+> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
+
 ### Download from ModelScope Hub

 If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
diff --git a/README_zh.md b/README_zh.md
index d8e17b29..09a7f330 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -454,6 +454,9 @@ docker compose -f ./docker-compose.yml up -d
 CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```

+> [!TIP]
+> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
+
 ### 从魔搭社区下载

 如果您在 Hugging Face 模型和数据集的下载中遇到了问题，可以通过下述方法使用魔搭社区。
diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index b7a08f0b..712b6940 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -1,10 +1,11 @@
 import json
+import os
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
 from ..extras.logging import get_logger
-from ..extras.packages import is_fastapi_available
+from ..extras.packages import is_fastapi_available, is_pillow_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -25,7 +26,14 @@ if is_fastapi_available():
     from fastapi import HTTPException, status


+if is_pillow_available():
+    import requests
+    from PIL import Image
+
+
 if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
     from ..chat import ChatModel
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest

@@ -40,7 +48,9 @@ ROLE_MAPPING = {
 }


-def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+def _process_request(
+    request: "ChatCompletionRequest",
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

     if len(request.messages) == 0:
@@ -49,12 +59,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
     if request.messages[0].role == Role.SYSTEM:
         system = request.messages.pop(0).content
     else:
-        system = ""
+        system = None

     if len(request.messages) % 2 == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
+    image = None
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -66,6 +77,18 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
             arguments = message.tool_calls[0].function.arguments
             content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+        elif isinstance(message.content, list):
+            for input_item in message.content:
+                if input_item.type == "text":
+                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                else:
+                    image_url = input_item.image_url.url
+                    if os.path.isfile(image_url):
+                        image_path = open(image_url, "rb")
+                    else:
+                        image_path = requests.get(image_url, stream=True).raw
+
+                    image = Image.open(image_path).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

@@ -76,9 +99,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
         except Exception:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
-        tools = ""
+        tools = None

-    return input_messages, system, tools
+    return input_messages, system, tools, image


 def _create_stream_chat_completion_chunk(
@@ -97,11 +120,12 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = _process_request(request)
+    input_messages, system, tools, image = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -145,7 +169,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = _process_request(request)
+    input_messages, system, tools, image = _process_request(request)

     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -159,6 +183,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
+        image,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py
index 525fa6a7..055fa781 100644
--- a/src/llamafactory/api/protocol.py
+++ b/src/llamafactory/api/protocol.py
@@ -56,9 +56,19 @@ class FunctionCall(BaseModel):
     function: Function


+class ImageURL(BaseModel):
+    url: str
+
+
+class MultimodalInputItem(BaseModel):
+    type: Literal["text", "image_url"]
+    text: Optional[str] = None
+    image_url: Optional[ImageURL] = None
+
+
 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[str] = None
+    content: Optional[Union[str, List[MultimodalInputItem]]] = None
     tool_calls: Optional[List[FunctionCall]] = None
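
For reference, a minimal client-side sketch (not part of the patch) of how a request could exercise the new schema: a single user message carrying one `text` item and one `image_url` item, matching the `MultimodalInputItem` model added to `protocol.py`. It assumes the server was started as in the README snippet (`API_PORT=8000 llamafactory-cli api ...`) and exposes the OpenAI-compatible `/v1/chat/completions` route; the model name and image URL below are placeholders. Per `_process_request`, the `url` field may also be a local file path readable by the server.

```python
# Hypothetical client request against the new multimodal chat endpoint.
# Placeholders: model name, image URL, and the localhost:8000 address.
import requests

payload = {
    "model": "test",  # placeholder for the served model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
            ],
        }
    ],
}

# POST to the OpenAI-compatible chat completions route and print the reply text.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```

Plain-string `content` still works as before, since `ChatMessage.content` now accepts either a string or a list of `MultimodalInputItem`.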