This commit is contained in:
hiyouga 2024-06-18 22:08:56 +08:00
parent 24c160df3d
commit c96264bc47
4 changed files with 4 additions and 16 deletions

View File

@ -50,11 +50,6 @@ class BaseEngine(ABC):
generating_args: "GeneratingArguments",
) -> None: ...
@abstractmethod
async def start(
self,
) -> None: ...
@abstractmethod
async def chat(
self,

View File

@ -49,8 +49,6 @@ class ChatModel:
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
task = asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
task.result()
def chat(
self,

View File

@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
@staticmethod
def _process_args(
@ -259,9 +260,6 @@ class HuggingfaceEngine(BaseEngine):
return scores
async def start(self) -> None:
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
async def chat(
self,
messages: Sequence[Dict[str, str]],
@ -286,7 +284,7 @@ class HuggingfaceEngine(BaseEngine):
image,
input_kwargs,
)
async with self._semaphore:
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
@ -314,7 +312,7 @@ class HuggingfaceEngine(BaseEngine):
image,
input_kwargs,
)
async with self._semaphore:
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
@ -333,6 +331,6 @@ class HuggingfaceEngine(BaseEngine):
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self._semaphore:
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)

View File

@ -183,9 +183,6 @@ class VllmEngine(BaseEngine):
)
return result_generator
async def start(self) -> None:
pass
async def chat(
self,
messages: Sequence[Dict[str, str]],