support image input in api #3971 #4061

hiyouga 2024-06-06 02:29:55 +08:00
parent dc4a00dd63
commit 946f601136
4 changed files with 49 additions and 8 deletions

README.md

@@ -456,6 +456,9 @@ docker compose -f ./docker-compose.yml up -d
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
> Visit https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
### Download from ModelScope Hub
If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope.
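With this change, the OpenAI-compatible endpoint accepts image inputs. A minimal client-side sketch follows; the port matches the command above, while the API key, model name, and image URL are placeholders, not part of the commit:

```python
# Hypothetical client for the new image input, using the OpenAI Python SDK.
# Endpoint, key, model name, and image URL are assumptions, not from the commit.
from openai import OpenAI

client = OpenAI(api_key="0", base_url="http://localhost:8000/v1")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
        ],
    }
]
result = client.chat.completions.create(messages=messages, model="llama3")
print(result.choices[0].message.content)
```

The `text` and `image_url` items follow the `MultimodalInputItem` schema added in `protocol.py` below.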

README_zh.md

@@ -454,6 +454,9 @@ docker compose -f ./docker-compose.yml up -d
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
> See https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
### Download from ModelScope Hub
If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope via the methods below.

src/llamafactory/api/chat.py

@@ -1,10 +1,11 @@
import json
import os
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available
from ..extras.packages import is_fastapi_available, is_pillow_available
from .common import dictify, jsonify
from .protocol import (
    ChatCompletionMessage,
@@ -25,7 +26,14 @@ if is_fastapi_available():
    from fastapi import HTTPException, status


if is_pillow_available():
    import requests
    from PIL import Image


if TYPE_CHECKING:
    from numpy.typing import NDArray

    from ..chat import ChatModel
    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -40,7 +48,9 @@ ROLE_MAPPING = {
}


def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
def _process_request(
    request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

    if len(request.messages) == 0:
@@ -49,12 +59,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
    if request.messages[0].role == Role.SYSTEM:
        system = request.messages.pop(0).content
    else:
        system = ""
        system = None

    if len(request.messages) % 2 == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

    input_messages = []
    image = None
    for i, message in enumerate(request.messages):
        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
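The checks above require an odd number of messages with user (or tool) turns in the even positions, i.e. the u/a/u/a/u... pattern. For illustration, hypothetical request bodies, not part of the commit:

```python
# Hypothetical message lists: _process_request accepts odd-length,
# alternating user/assistant conversations that end on a user turn.
valid = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Summarize our chat."},
]
invalid = [  # even length -> HTTP 400 "Only supports u/a/u/a/u..."
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
```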
@@ -66,6 +77,18 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
            arguments = message.tool_calls[0].function.arguments
            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
        elif isinstance(message.content, list):
            for input_item in message.content:
                if input_item.type == "text":
                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                else:
                    image_url = input_item.image_url.url
                    if os.path.isfile(image_url):
                        image_path = open(image_url, "rb")
                    else:
                        image_path = requests.get(image_url, stream=True).raw

                    image = Image.open(image_path).convert("RGB")
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
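The new branch accepts either a path that exists on the API server or a remote URL. A standalone sketch of the same loading logic; the `load_image` helper is ours, not part of the commit:

```python
# Standalone sketch of the image-loading branch above: a local file path is
# opened directly, anything else is fetched over HTTP. Needs Pillow and requests.
import os

import requests
from PIL import Image


def load_image(image_url: str) -> Image.Image:
    if os.path.isfile(image_url):
        stream = open(image_url, "rb")  # path on the API server
    else:
        stream = requests.get(image_url, stream=True).raw  # remote URL
    return Image.open(stream).convert("RGB")  # normalize to 3-channel RGB
```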
@@ -76,9 +99,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
        except Exception:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
    else:
        tools = ""
        tools = None

    return input_messages, system, tools
    return input_messages, system, tools, image
def _create_stream_chat_completion_chunk(
@@ -97,11 +120,12 @@ async def create_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools = _process_request(request)
    input_messages, system, tools, image = _process_request(request)
    responses = await chat_model.achat(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
@@ -145,7 +169,7 @@ async def create_stream_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools = _process_request(request)
    input_messages, system, tools, image = _process_request(request)

    if tools:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -159,6 +183,7 @@ async def create_stream_chat_completion_response(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
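Streaming goes through the same `_process_request`, so the decoded image is forwarded to `astream_chat` as well, while function calls remain unsupported in streaming mode. Continuing the hypothetical client sketch from the README section above:

```python
# Hedged sketch: the same multimodal messages also work with stream=True;
# requests that include tools are rejected with HTTP 400 (see above).
stream = client.chat.completions.create(messages=messages, model="llama3", stream=True)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```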

src/llamafactory/api/protocol.py

@@ -56,9 +56,19 @@ class FunctionCall(BaseModel):
    function: Function


class ImageURL(BaseModel):
    url: str


class MultimodalInputItem(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageURL] = None


class ChatMessage(BaseModel):
    role: Role
    content: Optional[str] = None
    content: Optional[Union[str, List[MultimodalInputItem]]] = None
    tool_calls: Optional[List[FunctionCall]] = None
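With `content` widened to accept either a plain string or a list of `MultimodalInputItem`s, both message forms validate. A quick sketch, assuming the models defined in this file:

```python
# Sketch: both content shapes pass validation with the updated ChatMessage.
# ChatMessage, MultimodalInputItem, ImageURL, and Role come from protocol.py above.
plain = ChatMessage(role=Role.USER, content="Hello!")
multimodal = ChatMessage(
    role=Role.USER,
    content=[
        MultimodalInputItem(type="text", text="What is in this picture?"),
        MultimodalInputItem(type="image_url", image_url=ImageURL(url="/path/to/image.png")),
    ],
)
```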