update api to match langchain

hiyouga 2023-07-07 20:35:39 +08:00
parent 233f20864b
commit 84a06318d4
1 changed file with 42 additions and 22 deletions


@@ -14,7 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from sse_starlette import EventSourceResponse
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional
 from utils import (
     Template,
@@ -46,17 +46,17 @@ app.add_middleware(
 class ModelCard(BaseModel):
     id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "owner"
+    object: Optional[str] = "model"
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    owned_by: Optional[str] = "owner"
     root: Optional[str] = None
     parent: Optional[str] = None
-    permission: Optional[list] = None
+    permission: Optional[list] = []
 
 
 class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
+    object: Optional[str] = "list"
+    data: Optional[List[ModelCard]] = []
 
 
 class ChatMessage(BaseModel):
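
With object, created, owned_by, and data relaxed to Optional, model cards that omit or null these fields still validate. A hypothetical check against a locally served instance (host and port are assumptions, not part of this commit):

    import requests

    # List the served models; the relaxed ModelCard schema tolerates missing fields.
    models = requests.get("http://localhost:8000/v1/models").json()
    for card in models["data"]:
        print(card["id"], card.get("owned_by"))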
@@ -74,8 +74,8 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = None
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = None
     stream: Optional[bool] = False
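
With max_length and max_new_tokens replaced by the OpenAI-style n and max_tokens, a request body now matches what LangChain's OpenAI integration sends. A minimal sketch, assuming a local deployment and an illustrative model id:

    import requests

    payload = {
        "model": "chatglm2-6b",  # illustrative model id, not from this commit
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
        "top_p": 0.9,
        "n": 1,                  # accepted for compatibility
        "max_tokens": 256,       # mapped to max_new_tokens server-side
        "stream": False,
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])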
@@ -88,14 +88,30 @@ class ChatCompletionResponseChoice(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]]
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
 
 
 class ChatCompletionResponse(BaseModel):
-    model: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    id: Optional[str] = "chatcmpl-default"
+    object: Literal["chat.completion"]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: Optional[str] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
 
 
 @app.get("/v1/models", response_model=ModelList)
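
Splitting the old union-typed response into ChatCompletionResponse and ChatCompletionStreamResponse pins down the shape each endpoint returns. A sketch of building the non-streaming reply with the new models (field values illustrative):

    usage = ChatCompletionResponseUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46)
    choice = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content="Hi there!"),
        finish_reason="stop"
    )
    resp = ChatCompletionResponse(
        model="chatglm2-6b",  # illustrative model id
        choices=[choice],
        usage=usage,
        object="chat.completion"
    )
    print(resp.json(ensure_ascii=False))  # serializes in the standard OpenAI shape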
@@ -135,12 +151,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
         "logits_processor": get_logits_processor()
     })
 
-    if request.max_length:
-        gen_kwargs.pop("max_new_tokens", None)
-        gen_kwargs["max_length"] = request.max_length
-    if request.max_new_tokens:
+    if request.max_tokens:
         gen_kwargs.pop("max_length", None)
-        gen_kwargs["max_new_tokens"] = request.max_new_tokens
+        gen_kwargs["max_new_tokens"] = request.max_tokens
 
     if request.stream:
         generate = predict(gen_kwargs, request.model)
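
Since the request schema now mirrors OpenAI's, an off-the-shelf client can target the server directly. A sketch using the openai 0.x client, assuming a local address and no authentication:

    import openai

    openai.api_base = "http://localhost:8000/v1"  # assumed local address
    openai.api_key = "none"                       # server performs no auth check

    completion = openai.ChatCompletion.create(
        model="chatglm2-6b",  # illustrative model id
        messages=[{"role": "user", "content": "Say hi."}],
        max_tokens=128,       # forwarded to the generator as max_new_tokens
    )
    print(completion.choices[0].message.content)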
@@ -150,13 +164,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
     outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
 
+    usage = ChatCompletionResponseUsage(
+        prompt_tokens=len(inputs["input_ids"][0]),
+        completion_tokens=len(outputs),
+        total_tokens=len(inputs["input_ids"][0]) + len(outputs)
+    )
+
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
         finish_reason="stop"
     )
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
+
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
 
 
 async def predict(gen_kwargs: Dict[str, Any], model_id: str):
@@ -173,7 +193,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     for new_text in streamer:
@@ -185,7 +205,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -193,7 +213,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield chunk.json(exclude_unset=True, ensure_ascii=False)
     yield "[DONE]"