forked from p04798526/LLaMA-Factory-Mirror
update api to match langchain
This commit is contained in:
parent
233f20864b
commit
84a06318d4
|
@ -14,7 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from contextlib import asynccontextmanager
|
||||
from transformers import TextIteratorStreamer
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from utils import (
|
||||
Template,
|
||||
|
@ -46,17 +46,17 @@ app.add_middleware(
|
|||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "owner"
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = None
|
||||
permission: Optional[list] = []
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = []
|
||||
object: Optional[str] = "list"
|
||||
data: Optional[List[ModelCard]] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
@ -74,8 +74,8 @@ class ChatCompletionRequest(BaseModel):
|
|||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
max_length: Optional[int] = None
|
||||
max_new_tokens: Optional[int] = None
|
||||
n: Optional[int] = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
|
||||
|
@ -88,14 +88,30 @@ class ChatCompletionResponseChoice(BaseModel):
|
|||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]]
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
model: str
|
||||
object: Literal["chat.completion", "chat.completion.chunk"]
|
||||
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion"]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: ChatCompletionResponseUsage
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Literal["chat.completion.chunk"]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
|
@ -135,12 +151,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||
"top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
|
||||
"logits_processor": get_logits_processor()
|
||||
})
|
||||
if request.max_length:
|
||||
gen_kwargs.pop("max_new_tokens", None)
|
||||
gen_kwargs["max_length"] = request.max_length
|
||||
if request.max_new_tokens:
|
||||
|
||||
if request.max_tokens:
|
||||
gen_kwargs.pop("max_length", None)
|
||||
gen_kwargs["max_new_tokens"] = request.max_new_tokens
|
||||
gen_kwargs["max_new_tokens"] = request.max_tokens
|
||||
|
||||
if request.stream:
|
||||
generate = predict(gen_kwargs, request.model)
|
||||
|
@ -150,13 +164,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=len(inputs["input_ids"][0]),
|
||||
completion_tokens=len(outputs),
|
||||
total_tokens=len(inputs["input_ids"][0]) + len(outputs)
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content=response),
|
||||
finish_reason="stop"
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
|
||||
|
||||
|
||||
async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
||||
|
@ -173,7 +193,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
|||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
for new_text in streamer:
|
||||
|
@ -185,7 +205,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
|||
delta=DeltaMessage(content=new_text),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
|
@ -193,7 +213,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
|||
delta=DeltaMessage(),
|
||||
finish_reason="stop"
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield "[DONE]"
|
||||
|
||||
|
|
Loading…
Reference in New Issue