fix streaming response in API

commit 4abd2485e1
parent e6603977f6
README.md
@@ -88,7 +88,7 @@ huggingface-cli login
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
-- uvicorn and fastapi (used in api_demo.py)
+- uvicorn, fastapi and sse_starlette (used in api_demo.py)
 
 And **powerful GPUs**!
 
requirements.txt
@@ -4,6 +4,7 @@ datasets>=2.12.0
 accelerate>=0.19.0
 peft>=0.3.0
 trl>=0.4.4
+sentencepiece
 jieba
 rouge_chinese
 nltk
@@ -11,4 +12,4 @@ gradio
 mdtex2html
 uvicorn
 fastapi
-sentencepiece
+sse_starlette
api_demo.py
@@ -13,7 +13,7 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
-from starlette.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 
 from utils import (
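For context on why the import changes, here is a minimal sketch (not the repository's code) contrasting the two response classes: with Starlette's StreamingResponse the generator must emit raw SSE frames itself, whereas sse_starlette's EventSourceResponse turns each yielded string into a properly framed `data:` event. The endpoints `/raw`, `/sse` and the `numbers()` generator below are made up for illustration.

```python
from fastapi import FastAPI
from starlette.responses import StreamingResponse
from sse_starlette import EventSourceResponse

app = FastAPI()

async def numbers():
    # Toy async generator standing in for the token stream.
    for i in range(3):
        yield str(i)

@app.get("/raw")
async def raw():
    # With StreamingResponse, the caller is responsible for the SSE wire format.
    async def frames():
        async for item in numbers():
            yield "data: {}\n\n".format(item)
    return StreamingResponse(frames(), media_type="text/event-stream")

@app.get("/sse")
async def sse():
    # EventSourceResponse adds the "data:" prefix and blank-line separator itself,
    # so the generator can yield plain strings.
    return EventSourceResponse(numbers(), media_type="text/event-stream")
```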
@@ -144,7 +144,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(gen_kwargs, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     generation_output = model.generate(**gen_kwargs)
     outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
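A small aside on why `generate = predict(gen_kwargs, request.model)` needs no await: `predict` is an async generator function, so calling it only creates an async generator object, and EventSourceResponse drives it while streaming the response. The `gen()`/`main()` names below are illustrative, not from the repository.

```python
import asyncio

async def gen():
    # Stand-in for predict(): an async generator function.
    yield "first"
    yield "second"

async def main():
    agen = gen()              # no await here: this only creates the generator object
    async for item in agen:   # EventSourceResponse iterates it the same way,
        print(item)           # framing each yielded string as one SSE event

asyncio.run(main())
```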
@@ -174,7 +174,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     for new_text in streamer:
         if len(new_text) == 0:
@@ -186,7 +186,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
@@ -194,7 +194,8 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
+    yield "[DONE]"
 
 
 if __name__ == "__main__":
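The three hunks above all drop the manual `"data: {}\n\n"` wrapper because EventSourceResponse now adds that framing, and the last one appends an OpenAI-style `[DONE]` sentinel. Below is a condensed, hedged sketch of the same generator pattern; `stream_chunks` and its `chunk` helper use plain dicts as stand-ins for the repository's pydantic models (ChatCompletionResponseStreamChoice, DeltaMessage, ChatCompletionResponse are not reproduced here).

```python
import json
from typing import AsyncGenerator, Iterable

async def stream_chunks(tokens: Iterable[str], model_id: str) -> AsyncGenerator[str, None]:
    def chunk(delta: dict, finish_reason=None) -> str:
        body = {
            "model": model_id,
            "object": "chat.completion.chunk",
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        # Plain JSON string: EventSourceResponse adds the "data: ...\n\n" SSE framing.
        return json.dumps(body, ensure_ascii=False)

    yield chunk({"role": "assistant"})          # opening chunk announces the role
    for text in tokens:
        if not text:
            continue
        yield chunk({"content": text})          # one chunk per decoded piece
    yield chunk({}, finish_reason="stop")       # closing chunk carries the finish reason
    yield "[DONE]"                              # end-of-stream sentinel, sent as a bare string
```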
@@ -204,4 +205,4 @@ if __name__ == "__main__":
     prompt_template = Template(data_args.prompt_template)
     source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
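Finally, a hedged client-side sketch of consuming the stream. The endpoint path (`/v1/chat/completions`), the payload fields, and the placeholder model name follow the OpenAI-style schema suggested by ChatCompletionRequest, but they are assumptions here rather than something this diff shows; the host and port match the uvicorn.run call above.

```python
import json
import requests

payload = {
    "model": "default",  # placeholder: use whatever model name the server expects
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

# Assumed endpoint path; adjust if the API mounts the route elsewhere.
url = "http://127.0.0.1:8000/v1/chat/completions"

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip blank separators and any comment/keep-alive lines
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break  # final sentinel yielded by predict()
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
```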