fix streaming response in API
This commit is contained in:
parent
e6603977f6
commit
4abd2485e1
|
@ -88,7 +88,7 @@ huggingface-cli login
|
|||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
|
||||
- jieba, rouge_chinese and nltk (used at evaluation)
|
||||
- gradio and mdtex2html (used in web_demo.py)
|
||||
- uvicorn and fastapi (used in api_demo.py)
|
||||
- uvicorn, fastapi and sse_starlette (used in api_demo.py)
|
||||
|
||||
And **powerful GPUs**!
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ datasets>=2.12.0
|
|||
accelerate>=0.19.0
|
||||
peft>=0.3.0
|
||||
trl>=0.4.4
|
||||
sentencepiece
|
||||
jieba
|
||||
rouge_chinese
|
||||
nltk
|
||||
|
@ -11,4 +12,4 @@ gradio
|
|||
mdtex2html
|
||||
uvicorn
|
||||
fastapi
|
||||
sentencepiece
|
||||
sse_starlette
|
||||
|
|
|
@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from transformers import TextIteratorStreamer
|
||||
from starlette.responses import StreamingResponse
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from utils import (
|
||||
|
@ -144,7 +144,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||
|
||||
if request.stream:
|
||||
generate = predict(gen_kwargs, request.model)
|
||||
return StreamingResponse(generate, media_type="text/event-stream")
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
generation_output = model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
|
@ -174,7 +174,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
|||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
for new_text in streamer:
|
||||
if len(new_text) == 0:
|
||||
|
@ -186,7 +186,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
|||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
|
@ -194,7 +194,8 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
|
|||
finish_reason="stop"
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -204,4 +205,4 @@ if __name__ == "__main__":
|
|||
prompt_template = Template(data_args.prompt_template)
|
||||
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
|
||||
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
|
Loading…
Reference in New Issue