diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py index 8afcdb7e..1db329c2 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chat.py @@ -1,5 +1,4 @@ -import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple from llmtuner.chat.stream_chat import ChatModel from llmtuner.extras.misc import torch_gc @@ -11,11 +10,10 @@ from llmtuner.webui.locales import ALERTS class WebChatModel(ChatModel): def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None: - if lazy_init: - self.model = None - self.tokenizer = None - self.generating_args = GeneratingArguments() - else: + self.model = None + self.tokenizer = None + self.generating_args = GeneratingArguments() + if not lazy_init: super().__init__(args) def load_model( @@ -30,7 +28,7 @@ class WebChatModel(ChatModel): flash_attn: bool, shift_attn: bool, rope_scaling: str - ): + ) -> Generator[str, None, None]: if self.model is not None: yield ALERTS["err_exists"][lang] return @@ -65,7 +63,7 @@ class WebChatModel(ChatModel): yield ALERTS["info_loaded"][lang] - def unload_model(self, lang: str): + def unload_model(self, lang: str) -> Generator[str, None, None]: yield ALERTS["info_unloading"][lang] self.model = None self.tokenizer = None @@ -81,16 +79,15 @@ class WebChatModel(ChatModel): max_new_tokens: int, top_p: float, temperature: float - ): + ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]: chatbot.append([query, ""]) response = "" for new_text in self.stream_chat( query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature ): response += new_text - response = self.postprocess(response) new_history = history + [(query, response)] - chatbot[-1] = [query, response] + chatbot[-1] = [query, self.postprocess(response)] yield chatbot, new_history def postprocess(self, response: str) -> str: diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index ef8d2adc..fd0044b6 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -1,8 +1,7 @@ -import json import os -from typing import Any, Dict, Optional - +import json import gradio as gr +from typing import Any, Dict, Optional from transformers.utils import ( WEIGHTS_NAME, WEIGHTS_INDEX_NAME, diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 928a568c..9de397eb 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -1,6 +1,5 @@ -from typing import TYPE_CHECKING, Dict, Optional, Tuple - import gradio as gr +from typing import TYPE_CHECKING, Dict, Optional, Tuple if TYPE_CHECKING: from gradio.blocks import Block diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index 089d02e5..9a4d5a8d 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -1,5 +1,5 @@ -from typing import TYPE_CHECKING, Dict import gradio as gr +from typing import TYPE_CHECKING, Dict from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.components.data import create_preview_box diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index 1b18fca0..6d11c003 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -1,5 +1,5 @@ -from typing import TYPE_CHECKING, Dict import gradio as gr +from typing import TYPE_CHECKING, Dict from llmtuner.webui.utils import save_model diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index 49505b67..e4488a4f 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -1,6 +1,5 @@ -from typing import TYPE_CHECKING, Dict - import gradio as gr +from typing import TYPE_CHECKING, Dict from llmtuner.webui.chat import WebChatModel from llmtuner.webui.components.chatbot import create_chat_box diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index e675009d..8c4698dd 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -1,6 +1,5 @@ -from typing import TYPE_CHECKING, Dict - import gradio as gr +from typing import TYPE_CHECKING, Dict from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.template import templates diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 050a44da..5b74034e 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -1,8 +1,7 @@ +import gradio as gr from typing import TYPE_CHECKING, Dict from transformers.trainer_utils import SchedulerType -import gradio as gr - from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.components.data import create_preview_box diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 3cc32332..08d4557c 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -1,11 +1,12 @@ -import gradio as gr -import logging import os -import threading import time +import logging +import threading +import gradio as gr +from typing import Any, Dict, Generator, List, Tuple + import transformers from transformers.trainer import TRAINING_ARGS_NAME -from typing import Any, Dict, Generator, List, Tuple from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES