Compare commits
1 Commits
main
...
hiyouga-pa
Author | SHA1 | Date |
---|---|---|
hoshi-hiyouga | 8a1da822ef |
|
@ -24,6 +24,7 @@ import numpy as np
|
||||||
|
|
||||||
from ..data import Role as DataRole
|
from ..data import Role as DataRole
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
|
from ..extras.misc import torch_gc
|
||||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||||
from .common import dictify, jsonify
|
from .common import dictify, jsonify
|
||||||
from .protocol import (
|
from .protocol import (
|
||||||
|
@ -184,6 +185,7 @@ async def create_chat_completion_response(
|
||||||
prompt_length = response.prompt_length
|
prompt_length = response.prompt_length
|
||||||
response_length += response.response_length
|
response_length += response.response_length
|
||||||
|
|
||||||
|
torch_gc()
|
||||||
usage = ChatCompletionResponseUsage(
|
usage = ChatCompletionResponseUsage(
|
||||||
prompt_tokens=prompt_length,
|
prompt_tokens=prompt_length,
|
||||||
completion_tokens=response_length,
|
completion_tokens=response_length,
|
||||||
|
@ -223,6 +225,7 @@ async def create_stream_chat_completion_response(
|
||||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch_gc()
|
||||||
yield _create_stream_chat_completion_chunk(
|
yield _create_stream_chat_completion_chunk(
|
||||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||||
)
|
)
|
||||||
|
@ -236,4 +239,5 @@ async def create_score_evaluation_response(
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||||
|
|
||||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||||
|
torch_gc()
|
||||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||||
|
|
Loading…
Reference in New Issue