forked from p04798526/LLaMA-Factory-Mirror
update API
This commit is contained in:
parent
76ecb8c222
commit
614d3a996c
|
@ -1,13 +1,13 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Implements API for fine-tuned models.
|
# Implements API for fine-tuned models.
|
||||||
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||||
|
# Visit http://localhost:8000/docs for documentation.
|
||||||
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, HTTPException
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
@ -36,12 +36,12 @@ app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: Literal["system", "user", "assistant"]
|
role: Literal["user", "assistant", "system"]
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class DeltaMessage(BaseModel):
|
class DeltaMessage(BaseModel):
|
||||||
role: Optional[Literal["system", "user", "assistant"]] = None
|
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,7 +77,10 @@ class ChatCompletionResponse(BaseModel):
|
||||||
async def create_chat_completion(request: ChatCompletionRequest):
|
async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
global model, tokenizer, source_prefix
|
global model, tokenizer, source_prefix
|
||||||
|
|
||||||
|
if request.messages[-1].role != "user":
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid request")
|
||||||
query = request.messages[-1].content
|
query = request.messages[-1].content
|
||||||
|
|
||||||
prev_messages = request.messages[:-1]
|
prev_messages = request.messages[:-1]
|
||||||
if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
if len(prev_messages) > 0 and prev_messages[0].role == "system":
|
||||||
source_prefix = prev_messages.pop(0).content
|
source_prefix = prev_messages.pop(0).content
|
||||||
|
|
Loading…
Reference in New Issue