diff --git a/src/api_demo.py b/src/api_demo.py
index 425e89fe..a0a82321 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -14,7 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from sse_starlette import EventSourceResponse
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional
 
 from utils import (
     Template,
@@ -46,17 +46,17 @@ app.add_middleware(
 
 class ModelCard(BaseModel):
     id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "owner"
+    object: Optional[str] = "model"
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    owned_by: Optional[str] = "owner"
     root: Optional[str] = None
     parent: Optional[str] = None
-    permission: Optional[list] = None
+    permission: Optional[list] = []
 
 
 class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
+    object: Optional[str] = "list"
+    data: Optional[List[ModelCard]] = []
 
 
 class ChatMessage(BaseModel):
@@ -74,8 +74,8 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = None
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = None
     stream: Optional[bool] = False
 
 
@@ -88,14 +88,30 @@ class ChatCompletionResponseChoice(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]]
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
 
 
 class ChatCompletionResponse(BaseModel):
-    model: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    id: Optional[str] = "chatcmpl-default"
+    object: Literal["chat.completion"]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: Optional[str] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
 
 
 @app.get("/v1/models", response_model=ModelList)
@@ -135,12 +151,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
         "logits_processor": get_logits_processor()
     })
-    if request.max_length:
-        gen_kwargs.pop("max_new_tokens", None)
-        gen_kwargs["max_length"] = request.max_length
-    if request.max_new_tokens:
+
+    if request.max_tokens:
         gen_kwargs.pop("max_length", None)
-        gen_kwargs["max_new_tokens"] = request.max_new_tokens
+        gen_kwargs["max_new_tokens"] = request.max_tokens
 
     if request.stream:
         generate = predict(gen_kwargs, request.model)
@@ -150,13 +164,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
     outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
 
+    usage = ChatCompletionResponseUsage(
+        prompt_tokens=len(inputs["input_ids"][0]),
+        completion_tokens=len(outputs),
+        total_tokens=len(inputs["input_ids"][0]) + len(outputs)
+    )
+
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
         finish_reason="stop"
     )
 
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
 
 
 async def predict(gen_kwargs: Dict[str, Any], model_id: str):
@@ -173,7 +193,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     for new_text in streamer:
@@ -185,7 +205,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -193,7 +213,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield chunk.json(exclude_unset=True, ensure_ascii=False)
     yield "[DONE]"
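
The patch brings the non-streaming response in line with the OpenAI chat-completions schema (id/object/created/model/choices/usage, with `max_tokens` mapped onto `max_new_tokens`). As a reference only, a client could exercise the updated route roughly as in the sketch below; it is not part of the patch, and the host/port and model id are assumptions that depend on how api_demo.py is launched.

# Minimal client sketch against the updated /v1/chat/completions route.
# Assumes the server is reachable at http://localhost:8000 and accepts "default"
# as a model id; both are placeholders, not defined by this diff.
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 128,   # handled by the new `if request.max_tokens:` branch
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["usage"])  # prompt_tokens / completion_tokens / total_tokens from ChatCompletionResponseUsage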