diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index 72b2ae50..7f5bf8c4 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -16,9 +16,12 @@ import base64 import io import json import os +import re import uuid from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple +import numpy as np + from ..data import Role as DataRole from ..extras.logging import get_logger from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available @@ -104,7 +107,7 @@ def _process_request( input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text}) else: image_url = input_item.image_url.url - if image_url.startswith("data:image"): # base64 image + if re.match("^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1]) image_path = io.BytesIO(image_data) elif os.path.isfile(image_url): # local file @@ -112,7 +115,7 @@ def _process_request( else: # web uri image_path = requests.get(image_url, stream=True).raw - image = Image.open(image_path).convert("RGB") + image = np.array(Image.open(image_path).convert("RGB")) else: input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})