update API

This commit is contained in:
hiyouga 2023-06-22 20:46:24 +08:00
parent 76ecb8c222
commit 614d3a996c
1 changed file with 7 additions and 4 deletions

View File

@ -1,13 +1,13 @@
# coding=utf-8 # coding=utf-8
# Implements API for fine-tuned models. # Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for the API documentation.
import json
import time import time
import torch import torch
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI, HTTPException
from threading import Thread from threading import Thread
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -36,12 +36,12 @@ app = FastAPI(lifespan=lifespan)
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"] role: Literal["user", "assistant", "system"]
content: str content: str
class DeltaMessage(BaseModel): class DeltaMessage(BaseModel):
role: Optional[Literal["system", "user", "assistant"]] = None role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None content: Optional[str] = None
@ -77,7 +77,10 @@ class ChatCompletionResponse(BaseModel):
async def create_chat_completion(request: ChatCompletionRequest): async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer, source_prefix global model, tokenizer, source_prefix
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content query = request.messages[-1].content
prev_messages = request.messages[:-1] prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system": if len(prev_messages) > 0 and prev_messages[0].role == "system":
source_prefix = prev_messages.pop(0).content source_prefix = prev_messages.pop(0).content