fix streaming response in API

This commit is contained in:
hiyouga 2023-07-05 22:42:31 +08:00
parent e6603977f6
commit 4abd2485e1
3 changed files with 10 additions and 8 deletions

View File

@ -88,7 +88,7 @@ huggingface-cli login
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- jieba, rouge_chinese and nltk (used at evaluation)
- gradio and mdtex2html (used in web_demo.py)
- uvicorn and fastapi (used in api_demo.py)
- uvicorn, fastapi and sse_starlette (used in api_demo.py)
And **powerful GPUs**!

View File

@ -4,6 +4,7 @@ datasets>=2.12.0
accelerate>=0.19.0
peft>=0.3.0
trl>=0.4.4
sentencepiece
jieba
rouge_chinese
nltk
@ -11,4 +12,4 @@ gradio
mdtex2html
uvicorn
fastapi
sentencepiece
sse_starlette

View File

@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from transformers import TextIteratorStreamer
from starlette.responses import StreamingResponse
from sse_starlette import EventSourceResponse
from typing import Any, Dict, List, Literal, Optional, Union
from utils import (
@ -144,7 +144,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if request.stream:
generate = predict(gen_kwargs, request.model)
return StreamingResponse(generate, media_type="text/event-stream")
return EventSourceResponse(generate, media_type="text/event-stream")
generation_output = model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
@ -174,7 +174,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield chunk.json(exclude_unset=True, ensure_ascii=False)
for new_text in streamer:
if len(new_text) == 0:
@ -186,7 +186,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield chunk.json(exclude_unset=True, ensure_ascii=False)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@ -194,7 +194,8 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield "[DONE]"
if __name__ == "__main__":
@ -204,4 +205,4 @@ if __name__ == "__main__":
prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)