fix async stream api response

This commit is contained in:
hiyouga 2024-05-04 16:11:18 +08:00
parent ed8f8be752
commit 941924fdbd
1 changed files with 15 additions and 15 deletions

View File

@ -38,7 +38,7 @@ ROLE_MAPPING = {
}
async def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@ -77,11 +77,23 @@ async def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[
return input_messages, system, tools
def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = await _process_request(request)
input_messages, system, tools = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
@ -124,23 +136,11 @@ async def create_chat_completion_response(
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
async def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = await _process_request(request)
input_messages, system, tools = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")