diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 6d3d5afc..6d06d1d0 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -6,7 +6,7 @@ from typing_extensions import Annotated
 
 from ..chat import ChatModel
 from ..extras.misc import torch_gc
-from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
+from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
 from .chat import (
     create_chat_completion_response,
     create_score_evaluation_response,
@@ -22,7 +22,7 @@ from .protocol import (
 )
 
 
-if is_fastapi_availble():
+if is_fastapi_available():
     from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.middleware.cors import CORSMiddleware
     from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
diff --git a/src/llmtuner/api/chat.py b/src/llmtuner/api/chat.py
index 2a703877..76ddc88d 100644
--- a/src/llmtuner/api/chat.py
+++ b/src/llmtuner/api/chat.py
@@ -3,7 +3,8 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
-from ..extras.packages import is_fastapi_availble
+from ..extras.logging import get_logger
+from ..extras.packages import is_fastapi_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -20,7 +21,9 @@ from .protocol import (
 )
 
 
-if is_fastapi_availble():
+logger = get_logger(__name__)
+
+if is_fastapi_available():
     from fastapi import HTTPException, status
 
 
@@ -39,6 +42,9 @@ ROLE_MAPPING = {
 
 
 def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+    params = dictify(request)
+    logger.info(f"==== request ====\n{params}")
+
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
diff --git a/src/llmtuner/api/common.py b/src/llmtuner/api/common.py
index 5ad9a071..3e95d211 100644
--- a/src/llmtuner/api/common.py
+++ b/src/llmtuner/api/common.py
@@ -6,11 +6,11 @@ if TYPE_CHECKING:
     from pydantic import BaseModel
 
 
-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel", **kwargs) -> Dict[str, Any]:
     try:  # pydantic v2
-        return data.model_dump(exclude_unset=True)
+        return data.model_dump(**kwargs)
     except AttributeError:  # pydantic v1
-        return data.dict(exclude_unset=True)
+        return data.dict(**kwargs)
 
 
 def jsonify(data: "BaseModel") -> str:
diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py
index a7317eec..4c9e6492 100644
--- a/src/llmtuner/extras/packages.py
+++ b/src/llmtuner/extras/packages.py
@@ -20,7 +20,7 @@ def _get_package_version(name: str) -> "Version":
         return version.parse("0.0.0")
 
 
-def is_fastapi_availble():
+def is_fastapi_available():
     return _is_package_available("fastapi")
 
 