From 8a1da822efa03694f13721efb261036c240095ec Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 27 Aug 2024 17:03:20 +0800
Subject: [PATCH] Update chat.py

---
 src/llamafactory/api/chat.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index da055ac9..d64aec7d 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -24,6 +24,7 @@ import numpy as np
 
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
+from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -184,6 +185,7 @@ async def create_chat_completion_response(
         prompt_length = response.prompt_length
         response_length += response.response_length
 
+    torch_gc()
     usage = ChatCompletionResponseUsage(
         prompt_tokens=prompt_length,
         completion_tokens=response_length,
@@ -223,6 +225,7 @@ async def create_stream_chat_completion_response(
             completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
         )
 
+    torch_gc()
     yield _create_stream_chat_completion_chunk(
         completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
     )
@@ -236,4 +239,5 @@ async def create_score_evaluation_response(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+    torch_gc()
     return ScoreEvaluationResponse(model=request.model, scores=scores)
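
For reference, this four-line patch imports torch_gc from ..extras.misc and calls it after each completed API response (non-stream chat completion, stream chat completion, and score evaluation) so cached GPU memory is released between requests. Below is a minimal sketch of what such a helper typically looks like, assuming it wraps Python's garbage collector and PyTorch's CUDA cache APIs; the actual body in extras/misc.py may differ:

    import gc

    import torch


    def torch_gc() -> None:
        # Sketch of a GPU garbage-collection helper (not the verbatim
        # implementation from LLaMA-Factory's extras/misc.py).
        gc.collect()  # reclaim unreachable Python objects first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
            torch.cuda.ipc_collect()  # free CUDA IPC handles from dead processes

Calling it once per finished request keeps peak memory in check without paying the cache-flush cost on every generated token.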