Compare commits


1 Commit

Author:  hoshi-hiyouga
SHA1:    8a1da822ef
Message: Update chat.py
Date:    2024-08-27 17:03:20 +08:00
1 changed file with 4 additions and 0 deletions

chat.py

@@ -24,6 +24,7 @@ import numpy as np
 
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
+from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -184,6 +185,7 @@ async def create_chat_completion_response(
         prompt_length = response.prompt_length
         response_length += response.response_length
 
+    torch_gc()
     usage = ChatCompletionResponseUsage(
         prompt_tokens=prompt_length,
         completion_tokens=response_length,
@@ -223,6 +225,7 @@ async def create_stream_chat_completion_response(
                 completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
             )
 
+    torch_gc()
     yield _create_stream_chat_completion_chunk(
         completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
     )
@@ -236,4 +239,5 @@ async def create_score_evaluation_response(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+    torch_gc()
     return ScoreEvaluationResponse(model=request.model, scores=scores)
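
The torch_gc helper imported from ..extras.misc is not shown in this diff. As a minimal sketch, assuming it combines Python garbage collection with the standard backend cache-release calls (the actual implementation in the repository may cover more backends or differ in detail):

import gc

import torch


def torch_gc() -> None:
    """Collect Python garbage, then release cached accelerator memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the allocator
        torch.cuda.ipc_collect()  # release CUDA IPC handles left by dead processes
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()  # equivalent cache release for the Apple MPS backend

Note where the calls land in the diff: once per finished request, after the non-streaming response is fully assembled, after the final streamed chunk, and after scoring. Freeing cached memory between API calls this way avoids paying the synchronization cost of emptying the cache inside the generation loop itself.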