diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 5936955b..6d06d1d0 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -56,8 +56,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
         if api_key and (auth is None or auth.credentials != api_key):
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
 
     @app.get(
         "/v1/models",
@@ -77,12 +76,10 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     )
     async def create_chat_completion(request: ChatCompletionRequest):
         if not chat_model.engine.can_generate:
-            raise HTTPException(
-                status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if request.stream:
-            generate = create_stream_chat_completion_response(
-                request, chat_model)
+            generate = create_stream_chat_completion_response(request, chat_model)
             return EventSourceResponse(generate, media_type="text/event-stream")
         else:
             return await create_chat_completion_response(request, chat_model)
@@ -95,8 +92,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     )
     async def create_score_evaluation(request: ScoreEvaluationRequest):
         if chat_model.engine.can_generate:
-            raise HTTPException(
-                status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         return await create_score_evaluation_response(request, chat_model)
 
diff --git a/src/llmtuner/api/chat.py b/src/llmtuner/api/chat.py
index 3ab473d1..76ddc88d 100644
--- a/src/llmtuner/api/chat.py
+++ b/src/llmtuner/api/chat.py
@@ -3,8 +3,8 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
-from ..extras.packages import is_fastapi_available
 from ..extras.logging import get_logger
+from ..extras.packages import is_fastapi_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -20,6 +20,7 @@ from .protocol import (
     ScoreEvaluationResponse,
 )
 
+
 logger = get_logger(__name__)
 
 if is_fastapi_available():
@@ -41,13 +42,11 @@ ROLE_MAPPING = {
 
 
 def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
-
     params = dictify(request)
     logger.info(f"==== request ====\n{params}")
 
     if len(request.messages) == 0:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
     if request.messages[0].role == Role.SYSTEM:
         system = request.messages.pop(0).content
@@ -55,37 +54,29 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
         system = ""
 
     if len(request.messages) % 2 == 0:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
-                            detail="Only supports u/a/u/a/u...")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
     input_messages = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
         elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
 
         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
             name = message.tool_calls[0].function.name
             arguments = message.tool_calls[0].function.arguments
-            content = json.dumps(
-                {"name": name, "argument": arguments}, ensure_ascii=False)
-            input_messages.append(
-                {"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         else:
-            input_messages.append(
-                {"role": ROLE_MAPPING[message.role], "content": message.content})
+            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
     tool_list = request.tools
     if isinstance(tool_list, list) and len(tool_list):
         try:
-            tools = json.dumps([dictify(tool.function)
-                               for tool in tool_list], ensure_ascii=False)
+            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
         except Exception:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
         tools = ""
 
@@ -99,10 +90,8 @@ def _create_stream_chat_completion_chunk(
     index: Optional[int] = 0,
     finish_reason: Optional["Finish"] = None,
 ) -> str:
-    choice_data = ChatCompletionStreamResponseChoice(
-        index=index, delta=delta, finish_reason=finish_reason)
-    chunk = ChatCompletionStreamResponse(
-        id=completion_id, model=model, choices=[choice_data])
+    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
     return jsonify(chunk)
 
 
@@ -127,26 +116,21 @@ async def create_chat_completion_response(
     choices = []
     for i, response in enumerate(responses):
         if tools:
-            result = chat_model.engine.template.format_tools.extract(
-                response.response_text)
+            result = chat_model.engine.template.format_tools.extract(response.response_text)
         else:
             result = response.response_text
 
         if isinstance(result, tuple):
             name, arguments = result
             function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(
-                uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(
-                role=Role.ASSISTANT, tool_calls=[tool_call])
+            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
             finish_reason = Finish.TOOL
         else:
-            response_message = ChatCompletionMessage(
-                role=Role.ASSISTANT, content=result)
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
             finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
 
-        choices.append(ChatCompletionResponseChoice(
-            index=i, message=response_message, finish_reason=finish_reason))
+        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
         prompt_length = response.prompt_length
         response_length += response.response_length
 
@@ -165,16 +149,13 @@ async def create_stream_chat_completion_response(
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
     input_messages, system, tools = _process_request(request)
     if tools:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
-                            detail="Cannot stream function calls.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
     if request.n > 1:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
-                            detail="Cannot stream multiple responses.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
 
     yield _create_stream_chat_completion_chunk(
-        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(
-            role=Role.ASSISTANT, content="")
+        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
     )
     async for new_token in chat_model.astream_chat(
         input_messages,
@@ -188,8 +169,7 @@
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
-                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(
-                    content=new_token)
+                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
             )
 
     yield _create_stream_chat_completion_chunk(
@@ -202,8 +182,7 @@ async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
     if len(request.messages) == 0:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
     return ScoreEvaluationResponse(model=request.model, scores=scores)
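None of the hunks above change behavior; they only join statements that an earlier formatter had wrapped, so the endpoints can be exercised the same way before and after the patch. Below is a minimal client-side sketch for smoke-testing them. The base URL, port, and the `/v1/chat/completions` path are assumptions (only the `/v1/models` route is visible in these hunks), and the bearer token is a placeholder; the `Authorization: Bearer <key>` form matches what `verify_api_key` checks when an API key is configured.

```python
# Hypothetical smoke test for the endpoints touched by this patch.
# BASE_URL, the chat route path, and the token are assumptions, not taken from the diff.
import requests

BASE_URL = "http://localhost:8000"  # wherever the app built by create_app() is served
HEADERS = {"Authorization": "Bearer sk-test"}  # compared against api_key by verify_api_key

# Non-streaming completion: _process_request rejects an even number of
# messages (u/a/u/a/u...), so a single user turn is the simplest valid body.
payload = {
    "model": "test",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
}
resp = requests.post(f"{BASE_URL}/v1/chat/completions", headers=HEADERS, json=payload)
print(resp.status_code, resp.json())

# Streaming completion: request.stream=True returns an SSE stream
# (EventSourceResponse with media_type="text/event-stream"); each
# "data:" event carries one ChatCompletionStreamResponse chunk.
payload["stream"] = True
with requests.post(
    f"{BASE_URL}/v1/chat/completions", headers=HEADERS, json=payload, stream=True
) as stream_resp:
    for line in stream_resp.iter_lines():
        if line.startswith(b"data:"):
            print(line.decode())
```

Running both requests before and after applying the patch should produce identical responses, which is a quick way to confirm the refactor is purely cosmetic.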