From 614d3a996cd7a9444605b174d302ef9edd3c66c0 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 22 Jun 2023 20:46:24 +0800 Subject: [PATCH] update API --- src/api_demo.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/api_demo.py b/src/api_demo.py index 899aec11..a61fd632 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -1,13 +1,13 @@ # coding=utf-8 # Implements API for fine-tuned models. # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint +# Visit http://localhost:8000/docs for the API documentation. -import json import time import torch import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from threading import Thread from contextlib import asynccontextmanager @@ -36,12 +36,12 @@ app = FastAPI(lifespan=lifespan) class ChatMessage(BaseModel): - role: Literal["system", "user", "assistant"] + role: Literal["user", "assistant", "system"] content: str class DeltaMessage(BaseModel): - role: Optional[Literal["system", "user", "assistant"]] = None + role: Optional[Literal["user", "assistant", "system"]] = None content: Optional[str] = None @@ -77,7 +77,10 @@ class ChatCompletionResponse(BaseModel): async def create_chat_completion(request: ChatCompletionRequest): global model, tokenizer, source_prefix + if request.messages[-1].role != "user": + raise HTTPException(status_code=400, detail="Invalid request") query = request.messages[-1].content + prev_messages = request.messages[:-1] if len(prev_messages) > 0 and prev_messages[0].role == "system": source_prefix = prev_messages.pop(0).content