diff --git a/requirements.txt b/requirements.txt index c132683b..83523eae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ fire jieba rouge-chinese nltk -gradio>=3.36.0 +gradio==3.38.0 uvicorn pydantic==1.10.11 fastapi==0.95.1 diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 22f46266..50e96bb0 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -186,7 +186,7 @@ def get_train_args( # postprocess model_args model_args.compute_dtype = ( - torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32) + torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None) ) model_args.model_max_length = data_args.cutoff_len diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index a43cee9e..234d924c 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -1,7 +1,7 @@ import os import json import gradio as gr -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional from transformers.utils import ( WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -27,7 +27,6 @@ CKPT_NAMES = [ ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME ] -CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]] def get_save_dir(*args) -> os.PathLike: @@ -38,7 +37,7 @@ def get_config_path() -> os.PathLike: return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) -def load_config() -> CONFIG_CLASS: +def load_config() -> Dict[str, Any]: try: with open(get_config_path(), "r", encoding="utf-8") as f: return json.load(f) @@ -46,20 +45,20 @@ def load_config() -> CONFIG_CLASS: return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} -def save_config( - config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None -) -> None: +def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None: os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) - config["lang"] = lang or config["lang"] + user_config = load_config() + user_config["lang"] = lang or user_config["lang"] if model_name: - config["last_model"] = model_name - config["path_dict"][model_name] = model_path + user_config["last_model"] = model_name + user_config["path_dict"][model_name] = model_path with open(get_config_path(), "w", encoding="utf-8") as f: - json.dump(config, f, indent=2, ensure_ascii=False) + json.dump(user_config, f, indent=2, ensure_ascii=False) -def get_model_path(config: Dict[str, Any], model_name: str) -> str: - return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "") +def get_model_path(model_name: str) -> str: + user_config = load_config() + return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "") def get_module(model_name: str) -> str: diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index ec935ee6..d6dd7eed 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -17,10 +17,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: unload_btn = gr.Button() info_box = gr.Textbox(show_label=False, interactive=False) - - elem_dict.update(dict( - info_box=info_box, load_btn=load_btn, unload_btn=unload_btn - )) + elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box)) chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False) elem_dict.update(dict(chat_box=chat_box, **chat_elems)) diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index 86161467..8fad3cc9 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.template import templates -from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config +from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config from llmtuner.webui.utils import can_quantize if TYPE_CHECKING: @@ -12,7 +12,6 @@ if TYPE_CHECKING: def create_top() -> Dict[str, "Component"]: available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] - config = gr.State(value=load_config()) with gr.Row(): lang = gr.Dropdown(choices=["en", "zh"], scale=1) @@ -39,17 +38,17 @@ def create_top() -> Dict[str, "Component"]: model_name.change( list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False ).then( - get_model_path, [config, model_name], [model_path], queue=False + get_model_path, [model_name], [model_path], queue=False ).then( get_template, [model_name], [template], queue=False ) # do not save config since the below line will save - model_path.change(save_config, inputs=[config, lang, model_name, model_path]) + model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False) finetuning_type.change( - list_checkpoint, [model_name, finetuning_type], [checkpoints] + list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False ).then( - can_quantize, [finetuning_type], [quantization_bit] + can_quantize, [finetuning_type], [quantization_bit], queue=False ) refresh_btn.click( @@ -57,7 +56,6 @@ def create_top() -> Dict[str, "Component"]: ) return dict( - config=config, lang=lang, model_name=model_name, model_path=model_path, diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py index 03d06144..f0f6542e 100644 --- a/src/llmtuner/webui/engine.py +++ b/src/llmtuner/webui/engine.py @@ -3,7 +3,7 @@ from gradio.components import Component # cannot use TYPE_CHECKING here from typing import Any, Dict, Generator, Optional from llmtuner.webui.chatter import WebChatModel -from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS +from llmtuner.webui.common import get_model_path, list_dataset, load_config from llmtuner.webui.locales import LOCALES from llmtuner.webui.manager import Manager from llmtuner.webui.runner import Runner @@ -21,8 +21,9 @@ class Engine: def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]): return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()} - def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]: - lang = config.get("lang", None) or "en" + def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]: + user_config = load_config() + lang = user_config.get("lang", None) or "en" resume_dict = { "top.lang": {"value": lang}, @@ -33,9 +34,9 @@ class Engine: resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]} resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]} - if config.get("last_model", None): - resume_dict["top.model_name"] = {"value": config["last_model"]} - resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])} + if user_config.get("last_model", None): + resume_dict["top.model_name"] = {"value": user_config["last_model"]} + resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])} yield self._form_dict(resume_dict) diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 62f8e187..9388dcb1 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -9,12 +9,12 @@ from llmtuner.webui.components import ( create_export_tab, create_chat_box ) -from llmtuner.webui.common import load_config, save_config +from llmtuner.webui.common import save_config from llmtuner.webui.css import CSS from llmtuner.webui.engine import Engine -require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") +require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0") def create_ui() -> gr.Blocks: @@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks: with gr.Blocks(title="Web Tuner", css=CSS) as demo: engine.manager.all_elems["top"] = create_top() lang: "gr.Dropdown" = engine.manager.get_elem("top.lang") - config = engine.manager.get_elem("top.config") - model_name = engine.manager.get_elem("top.model_name") - model_path = engine.manager.get_elem("top.model_path") with gr.Tab("Train"): engine.manager.all_elems["train"] = create_train_tab(engine) @@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks: with gr.Tab("Export"): engine.manager.all_elems["export"] = create_export_tab(engine) - demo.load(engine.resume, [config], engine.manager.list_elems()) - - lang.change( - engine.change_lang, [lang], engine.manager.list_elems(), queue=False - ).then( - save_config, inputs=[config, lang, model_name, model_path] - ) + demo.load(engine.resume, outputs=engine.manager.list_elems()) + lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False) + lang.input(save_config, inputs=[lang], queue=False) return demo @@ -54,21 +47,15 @@ def create_web_demo() -> gr.Blocks: engine = Engine(pure_chat=True) with gr.Blocks(title="Web Demo", css=CSS) as demo: - config = gr.State(value=load_config()) lang = gr.Dropdown(choices=["en", "zh"]) - - engine.manager.all_elems["top"] = dict(config=config, lang=lang) + engine.manager.all_elems["top"] = dict(lang=lang) chat_box, _, _, chat_elems = create_chat_box(engine, visible=True) engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems) - demo.load(engine.resume, [config], engine.manager.list_elems()) - - lang.change( - engine.change_lang, [lang], engine.manager.list_elems(), queue=False - ).then( - save_config, inputs=[config, lang] - ) + demo.load(engine.resume, outputs=engine.manager.list_elems()) + lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False) + lang.input(save_config, inputs=[lang], queue=False) return demo diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py index ebd28463..118bd0a9 100644 --- a/src/llmtuner/webui/manager.py +++ b/src/llmtuner/webui/manager.py @@ -18,7 +18,6 @@ class Manager: def get_base_elems(self): return { - self.all_elems["top"]["config"], self.all_elems["top"]["lang"], self.all_elems["top"]["model_name"], self.all_elems["top"]["model_path"], diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index ed423978..9fde4d2c 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -14,7 +14,7 @@ from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc import torch_gc from llmtuner.tuner import run_exp -from llmtuner.webui.common import get_module, get_save_dir +from llmtuner.webui.common import get_module, get_save_dir, load_config from llmtuner.webui.locales import ALERTS from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar @@ -74,6 +74,7 @@ class Runner: def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]: get = lambda name: data[self.manager.get_elem(name)] + user_config = load_config() if get("top.checkpoints"): checkpoint_dir = ",".join([ @@ -89,7 +90,7 @@ class Runner: model_name_or_path=get("top.model_path"), do_train=True, overwrite_cache=False, - cache_dir=get("top.config").get("cache_dir", None), + cache_dir=user_config.get("cache_dir", None), checkpoint_dir=checkpoint_dir, finetuning_type=get("top.finetuning_type"), quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, @@ -142,6 +143,7 @@ class Runner: def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]: get = lambda name: data[self.manager.get_elem(name)] + user_config = load_config() if get("top.checkpoints"): checkpoint_dir = ",".join([ @@ -160,7 +162,7 @@ class Runner: do_eval=True, overwrite_cache=False, predict_with_generate=True, - cache_dir=get("top.config").get("cache_dir", None), + cache_dir=user_config.get("cache_dir", None), checkpoint_dir=checkpoint_dir, finetuning_type=get("top.finetuning_type"), quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,