fix #4335
This commit is contained in:
parent
24c160df3d
commit
c96264bc47
|
@ -50,11 +50,6 @@ class BaseEngine(ABC):
|
|||
generating_args: "GeneratingArguments",
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def start(
|
||||
self,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
|
|
|
@ -49,8 +49,6 @@ class ChatModel:
|
|||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||
self._thread.start()
|
||||
task = asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
|
||||
task.result()
|
||||
|
||||
def chat(
|
||||
self,
|
||||
|
|
|
@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
) # must after fixing tokenizer to resize vocab
|
||||
self.generating_args = generating_args.to_dict()
|
||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
||||
|
||||
@staticmethod
|
||||
def _process_args(
|
||||
|
@ -259,9 +260,6 @@ class HuggingfaceEngine(BaseEngine):
|
|||
|
||||
return scores
|
||||
|
||||
async def start(self) -> None:
|
||||
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
|
@ -286,7 +284,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
async with self.semaphore:
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||
|
||||
|
@ -314,7 +312,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||
image,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self._semaphore:
|
||||
async with self.semaphore:
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
stream = self._stream_chat(*input_args)
|
||||
while True:
|
||||
|
@ -333,6 +331,6 @@ class HuggingfaceEngine(BaseEngine):
|
|||
|
||||
loop = asyncio.get_running_loop()
|
||||
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
||||
async with self._semaphore:
|
||||
async with self.semaphore:
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
||||
|
|
|
@ -183,9 +183,6 @@ class VllmEngine(BaseEngine):
|
|||
)
|
||||
return result_generator
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
|
|
Loading…
Reference in New Issue