forked from p04798526/LLaMA-Factory-Mirror

fix api server

This commit is contained in:
parent d2a676c8ba
commit 08464183b9
@@ -1,4 +1,6 @@
+import os
 import json
+import asyncio
 from typing import List, Tuple
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
@@ -63,6 +65,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )

+    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
         model_card = ModelCard(id="gpt-3.5-turbo")
@@ -93,8 +97,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+
+    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = predict(query, history, system, request)
+            generate = stream_chat_completion(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

         responses = chat_model.chat(
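For context on the hunks above: the new code gates each request behind a shared asyncio.Semaphore (sized by the MAX_CONCURRENT environment variable) and then runs the blocking model call in the default thread-pool executor, so the event loop is not blocked while a generation is in flight. A minimal, self-contained sketch of that pattern follows; the slow_chat function and the /demo/chat route are illustrative stand-ins, not part of the commit.

import asyncio
import os
import time

from fastapi import FastAPI

app = FastAPI()

# Limit concurrent generations; defaults to 1 if MAX_CONCURRENT is unset.
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))


def slow_chat(prompt: str) -> str:
    # Stand-in for a blocking, GPU-bound chat_model.chat(...) call.
    time.sleep(1.0)
    return "echo: " + prompt


@app.post("/demo/chat")
async def demo_chat(prompt: str):
    async with semaphore:  # excess requests wait here instead of stacking up on the model
        loop = asyncio.get_running_loop()
        # Offload the blocking call to the default ThreadPoolExecutor so the
        # event loop keeps serving other endpoints meanwhile.
        reply = await loop.run_in_executor(None, slow_chat, prompt)
    return {"reply": reply}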
@@ -125,7 +134,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":

         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

-    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT, content=""),
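The rename from async def predict to def stream_chat_completion goes with the executor change above: the streaming body is now produced by a plain synchronous generator, which EventSourceResponse from sse-starlette can consume just like an async generator. A rough sketch under that assumption; the token_stream generator and the /demo/stream route are illustrative only.

import time

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


def token_stream(prompt: str):
    # A plain (synchronous) generator; each yielded string becomes one SSE event.
    for token in ["Hello", ",", " world", "[DONE]"]:
        time.sleep(0.1)  # stand-in for per-token generation latency
        yield token


@app.get("/demo/stream")
async def demo_stream(prompt: str = ""):
    return EventSourceResponse(token_stream(prompt), media_type="text/event-stream")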
@@ -168,7 +177,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":

         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, get_score, request)
+
+    def get_score(request: ScoreEvaluationRequest):
         scores = chat_model.get_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)

@@ -178,4 +192,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 if __name__ == "__main__":
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)

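As a usage note (not part of the commit): with this change the server port comes from API_PORT and the concurrency cap from MAX_CONCURRENT, both read from the environment at startup. A hypothetical client call against the default port, using the OpenAI-style /v1/chat/completions schema this file exposes, might look like:

import json
import os

import requests  # any HTTP client would do

port = int(os.environ.get("API_PORT", 8000))
resp = requests.post(
    "http://localhost:{}/v1/chat/completions".format(port),
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=60,
)
print(json.dumps(resp.json(), indent=2, ensure_ascii=False))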
@@ -152,7 +152,6 @@ class ChatModel:
             padding=True,
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
             pad_to_multiple_of=8,
             return_tensors="pt",
-            **kwargs
         ).to(device)
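For reference, the tokenizer arguments shown in this hunk behave as in the following sketch; the checkpoint and inputs are illustrative, not taken from the commit.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

batch_input = ["first candidate response", "a much longer second candidate response"]
inputs = tokenizer(
    batch_input,
    padding=True,            # pad shorter sequences up to the batch maximum
    truncation=True,         # cut sequences that exceed max_length
    max_length=1024,         # stand-in for max_position_embeddings
    pad_to_multiple_of=8,    # round the padded length up to a multiple of 8
    return_tensors="pt",     # return PyTorch tensors
)
print(inputs["input_ids"].shape)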