forked from p04798526/LLaMA-Factory-Mirror
update API
This commit is contained in:
parent
76ecb8c222
commit
614d3a996c
|
@ -1,13 +1,13 @@
|
|||
# coding=utf-8
|
||||
# Implements API for fine-tuned models.
|
||||
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
# Visit http://localhost:8000/docs for document.
|
||||
|
||||
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from threading import Thread
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
@ -36,12 +36,12 @@ app = FastAPI(lifespan=lifespan)
|
|||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant"]
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["system", "user", "assistant"]] = None
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -77,7 +77,10 @@ class ChatCompletionResponse(BaseModel):
|
|||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
global model, tokenizer, source_prefix
|
||||
|
||||
if request.messages[-1].role != "user":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
query = request.messages[-1].content
|
||||
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
||||
source_prefix = prev_messages.pop(0).content
|
||||
|
|
Loading…
Reference in New Issue