This commit is contained in:
hiyouga 2023-10-10 17:41:13 +08:00
parent 8e2ed6b8ce
commit e1dcb8e4dc
9 changed files with 22 additions and 29 deletions

View File

@ -1,5 +1,4 @@
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
@ -11,11 +10,10 @@ from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
if lazy_init:
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
else:
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
super().__init__(args)
def load_model(
@ -30,7 +28,7 @@ class WebChatModel(ChatModel):
flash_attn: bool,
shift_attn: bool,
rope_scaling: str
):
) -> Generator[str, None, None]:
if self.model is not None:
yield ALERTS["err_exists"][lang]
return
@ -65,7 +63,7 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang]
def unload_model(self, lang: str):
def unload_model(self, lang: str) -> Generator[str, None, None]:
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None
@ -81,16 +79,15 @@ class WebChatModel(ChatModel):
max_new_tokens: int,
top_p: float,
temperature: float
):
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
response = self.postprocess(response)
new_history = history + [(query, response)]
chatbot[-1] = [query, response]
chatbot[-1] = [query, self.postprocess(response)]
yield chatbot, new_history
def postprocess(self, response: str) -> str:

View File

@ -1,8 +1,7 @@
import json
import os
from typing import Any, Dict, Optional
import json
import gradio as gr
from typing import Any, Dict, Optional
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block

View File

@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box

View File

@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.utils import save_model

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates

View File

@ -1,8 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box

View File

@ -1,11 +1,12 @@
import gradio as gr
import logging
import os
import threading
import time
import logging
import threading
import gradio as gr
from typing import Any, Dict, Generator, List, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES