diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index fb800106..2a72f422 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from .base_engine import BaseEngine, Response -def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None: +def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: asyncio.set_event_loop(loop) loop.run_forever() @@ -49,7 +49,8 @@ class ChatModel: self._loop = asyncio.new_event_loop() self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) self._thread.start() - asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop) + task = asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop) + task.result() def chat( self,