forked from p04798526/LLaMA-Factory-Mirror
update api to match langchain
This commit is contained in: parent 233f20864b, commit 84a06318d4
@@ -14,7 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from sse_starlette import EventSourceResponse
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional
 
 from utils import (
     Template,
@@ -46,17 +46,17 @@ app.add_middleware(
 
 class ModelCard(BaseModel):
     id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "owner"
+    object: Optional[str] = "model"
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    owned_by: Optional[str] = "owner"
     root: Optional[str] = None
     parent: Optional[str] = None
-    permission: Optional[list] = None
+    permission: Optional[list] = []
 
 
 class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
+    object: Optional[str] = "list"
+    data: Optional[List[ModelCard]] = []
 
 
 class ChatMessage(BaseModel):
@@ -74,8 +74,8 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    max_length: Optional[int] = None
-    max_new_tokens: Optional[int] = None
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = None
     stream: Optional[bool] = False
 
 
@@ -88,14 +88,30 @@ class ChatCompletionResponseChoice(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]]
+    finish_reason: Optional[Literal["stop", "length"]] = None
 
 
+class ChatCompletionResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
 class ChatCompletionResponse(BaseModel):
-    model: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    id: Optional[str] = "chatcmpl-default"
+    object: Literal["chat.completion"]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: Optional[str] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
 
 
 @app.get("/v1/models", response_model=ModelList)
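To make the new response schema concrete, here is a condensed, standalone sketch of the models introduced above and the JSON they produce. It assumes pydantic v1 (consistent with the .json(..., ensure_ascii=False) calls later in this file); the ChatMessage and ChatCompletionResponseChoice definitions are abbreviated stand-ins for the ones elsewhere in the file, and the example values are made up.

import time
from typing import List, Literal, Optional

from pydantic import BaseModel, Field


class ChatMessage(BaseModel):  # abbreviated stand-in for the full definition
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):  # abbreviated stand-in
    index: int
    message: ChatMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: Optional[str] = "chatcmpl-default"
    object: Literal["chat.completion"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: ChatCompletionResponseUsage


# Example values only -- the real handler fills these from the tokenizer output.
response = ChatCompletionResponse(
    object="chat.completion",
    model="llama-7b",
    choices=[
        ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content="Hello!"),
            finish_reason="stop",
        )
    ],
    usage=ChatCompletionResponseUsage(prompt_tokens=10, completion_tokens=3, total_tokens=13),
)
print(response.json(ensure_ascii=False))

Splitting ChatCompletionResponse from ChatCompletionStreamResponse removes the Union of choice types the old model carried, so the non-stream payload matches what OpenAI-style clients expect to parse.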
@@ -135,12 +151,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
         "logits_processor": get_logits_processor()
     })
-    if request.max_length:
-        gen_kwargs.pop("max_new_tokens", None)
-        gen_kwargs["max_length"] = request.max_length
-    if request.max_new_tokens:
+    if request.max_tokens:
         gen_kwargs.pop("max_length", None)
-        gen_kwargs["max_new_tokens"] = request.max_new_tokens
+        gen_kwargs["max_new_tokens"] = request.max_tokens
 
     if request.stream:
         generate = predict(gen_kwargs, request.model)
@@ -150,13 +164,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
     outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
 
+    usage = ChatCompletionResponseUsage(
+        prompt_tokens=len(inputs["input_ids"][0]),
+        completion_tokens=len(outputs),
+        total_tokens=len(inputs["input_ids"][0]) + len(outputs)
+    )
+
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
     )
 
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
 
 
 async def predict(gen_kwargs: Dict[str, Any], model_id: str):
@@ -173,7 +193,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     for new_text in streamer:
@@ -185,7 +205,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -193,7 +213,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield chunk.json(exclude_unset=True, ensure_ascii=False)
     yield "[DONE]"
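A minimal client-side sketch of calling the updated server follows. It is only a sketch under assumptions that are not part of this diff: the server listens on http://localhost:8000, the handler is mounted at /v1/chat/completions (the route decorator sits outside the hunks shown), and the model name is a placeholder. The payload shape matches what an OpenAI-compatible wrapper such as langchain's ChatOpenAI would send when pointed at this base URL.

import json
import requests

# Assumed host, port and route -- none of these appear in the diff above.
API_URL = "http://localhost:8000/v1/chat/completions"

payload = {
    "model": "llama-7b",  # placeholder identifier
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 128,    # mapped onto max_new_tokens by the handler above
}

# Non-streaming call: a single chat.completion object with usage counts.
resp = requests.post(API_URL, json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["usage"])  # prompt_tokens / completion_tokens / total_tokens

# Streaming call: server-sent events carrying chat.completion.chunk objects,
# terminated by the literal [DONE] marker yielded by predict().
with requests.post(API_URL, json={**payload, "stream": True}, stream=True, timeout=120) as r:
    r.raise_for_status()
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        body = line[len("data:"):].strip()
        if body == "[DONE]":
            break
        # delta may omit "content" because the chunks are serialized with exclude_unset=True
        delta = json.loads(body)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)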