fix streaming response in API

hiyouga 2023-07-05 22:42:31 +08:00
parent e6603977f6
commit 4abd2485e1
3 changed files with 10 additions and 8 deletions

README.md

@@ -88,7 +88,7 @@ huggingface-cli login
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
-- uvicorn and fastapi (used in api_demo.py)
+- uvicorn, fastapi and sse_starlette (used in api_demo.py)
 And **powerful GPUs**!

requirements.txt

@@ -4,6 +4,7 @@ datasets>=2.12.0
 accelerate>=0.19.0
 peft>=0.3.0
 trl>=0.4.4
+sentencepiece
 jieba
 rouge_chinese
 nltk
@@ -11,4 +12,4 @@ gradio
 mdtex2html
 uvicorn
 fastapi
-sentencepiece
+sse_starlette

api_demo.py

@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
-from starlette.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from utils import (
@@ -144,7 +144,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.stream:
         generate = predict(gen_kwargs, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
     generation_output = model.generate(**gen_kwargs)
     outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
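
Context for the swap above (a note on this commit, not part of it): starlette's StreamingResponse sends yielded strings over the wire verbatim, so the generator had to hand-build the SSE framing itself; sse_starlette's EventSourceResponse wraps each yielded string in a "data:" event automatically. A minimal sketch of that behavior, with hypothetical payloads and a hypothetical /demo-stream route:

from fastapi import FastAPI
from sse_starlette import EventSourceResponse

app = FastAPI()

async def demo_events():
    # Each yielded string goes out as one SSE event: data: <string>
    yield '{"delta": "Hel"}'
    yield '{"delta": "lo"}'
    yield "[DONE]"  # OpenAI-style terminator, matching the new yield in predict() below

@app.get("/demo-stream")
async def demo_stream():
    return EventSourceResponse(demo_events())

This is why the predict() hunks below drop the manual "data: {}\n\n".format(...) wrapping and yield the bare JSON instead.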
@@ -174,7 +174,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
     for new_text in streamer:
         if len(new_text) == 0:
@@ -186,7 +186,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -194,7 +194,8 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
+    yield "[DONE]"
 if __name__ == "__main__":
@@ -204,4 +205,4 @@ if __name__ == "__main__":
     prompt_template = Template(data_args.prompt_template)
     source_prefix = data_args.source_prefix if data_args.source_prefix else ""
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
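
A quick way to eyeball the fixed wire format end to end. This client is hypothetical and not part of the commit; the /v1/chat/completions route and the request shape are assumptions based on the OpenAI-style schema the demo mirrors, and the model field is a placeholder the server echoes back:

import httpx

payload = {
    "model": "placeholder-model",
    "stream": True,
    "messages": [{"role": "user", "content": "Hello"}],
}

# Stream the response and print each SSE payload; the last one should be [DONE].
with httpx.stream("POST", "http://127.0.0.1:8000/v1/chat/completions",
                  json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        # EventSourceResponse emits lines like: data: {"choices": [...]}
        if line.startswith("data: "):
            print(line[len("data: "):])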