diff --git a/README.md b/README.md
index f45a312b..e548709e 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ huggingface-cli login
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
-- uvicorn and fastapi (used in api_demo.py)
+- uvicorn, fastapi and sse_starlette (used in api_demo.py)
 
 And **powerful GPUs**!
 
diff --git a/requirements.txt b/requirements.txt
index f0fbdf2d..50a67c33 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ datasets>=2.12.0
 accelerate>=0.19.0
 peft>=0.3.0
 trl>=0.4.4
+sentencepiece
 jieba
 rouge_chinese
 nltk
@@ -11,4 +12,4 @@ gradio
 mdtex2html
 uvicorn
 fastapi
-sentencepiece
+sse_starlette
diff --git a/src/api_demo.py b/src/api_demo.py
index fd1d450a..425e89fe 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
-from starlette.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 
 from utils import (
@@ -144,7 +144,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(gen_kwargs, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     generation_output = model.generate(**gen_kwargs)
     outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
@@ -174,7 +174,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     for new_text in streamer:
         if len(new_text) == 0:
@@ -186,7 +186,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -194,7 +194,8 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
+    yield "[DONE]"
 
 
 if __name__ == "__main__":
@@ -204,4 +205,4 @@ if __name__ == "__main__":
 
     prompt_template = Template(data_args.prompt_template)
     source_prefix = data_args.source_prefix if data_args.source_prefix else ""
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
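
For reference, a minimal client sketch showing how the streaming endpoint could be consumed once it is served through sse_starlette's EventSourceResponse: sse_starlette adds the "data: " framing that the handler previously formatted by hand, and the new "[DONE]" yield marks the end of the stream. The server address, the POST /v1/chat/completions route, and the request fields other than model and stream are assumptions not shown in this diff.

# Minimal SSE client sketch (assumed host/route; adjust to the actual ChatCompletionRequest schema).
import json
import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with requests.post("http://127.0.0.1:8000/v1/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # sse_starlette frames each yielded string as a "data: ..." event line
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # sentinel yielded at the end of predict()
            break
        delta = json.loads(data)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)

Breaking on the "[DONE]" sentinel matches the final yield added to predict() above.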