Update chat.py
This commit is contained in:
parent
0f5a0f64f7
commit
8a1da822ef
|
@@ -24,6 +24,7 @@ import numpy as np
|
|||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
|
@@ -184,6 +185,7 @@ async def create_chat_completion_response(
|
|||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
torch_gc()
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
|
@@ -223,6 +225,7 @@ async def create_stream_chat_completion_response(
|
|||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
||||
)
|
||||
|
||||
torch_gc()
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
|
@@ -236,4 +239,5 @@ async def create_score_evaluation_response(
|
|||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
torch_gc()
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
|
|
Loading…
Reference in New Issue