fix api server

This commit is contained in:
hiyouga 2024-01-07 17:14:42 +08:00
parent d2a676c8ba
commit 08464183b9
2 changed files with 18 additions and 5 deletions

View File

@ -1,4 +1,6 @@
import os
import json
import asyncio
from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
@ -63,6 +65,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_headers=["*"],
)
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo")
@ -93,8 +97,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, chat_completion, query, history, system, request)
def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
if request.stream:
generate = predict(query, history, system, request)
generate = stream_chat_completion(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
responses = chat_model.chat(
@ -125,7 +134,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT, content=""),
@ -168,7 +177,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
async with semaphore:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_score, request)
def get_score(request: ScoreEvaluationRequest):
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(model=request.model, scores=scores)
@ -178,4 +192,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if __name__ == "__main__":
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)

View File

@ -152,7 +152,6 @@ class ChatModel:
padding=True,
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
pad_to_multiple_of=8,
return_tensors="pt",
**kwargs
).to(device)