update API

This commit is contained in:
hiyouga 2023-06-22 20:46:24 +08:00
parent 76ecb8c222
commit 614d3a996c
1 changed files with 7 additions and 4 deletions

View File

@ -1,13 +1,13 @@
# coding=utf-8
# Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
import json
import time
import torch
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from threading import Thread
from contextlib import asynccontextmanager
@ -36,12 +36,12 @@ app = FastAPI(lifespan=lifespan)
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
role: Literal["user", "assistant", "system"]
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["system", "user", "assistant"]] = None
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
@ -77,7 +77,10 @@ class ChatCompletionResponse(BaseModel):
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer, source_prefix
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
source_prefix = prev_messages.pop(0).content