This commit is contained in:
hiyouga 2023-10-15 18:28:45 +08:00
parent 0d63584c03
commit a6a04be2e6
9 changed files with 40 additions and 57 deletions

View File

@@ -12,7 +12,7 @@ fire
jieba
rouge-chinese
nltk
gradio>=3.36.0
gradio==3.38.0
uvicorn
pydantic==1.10.11
fastapi==0.95.1

View File

@@ -186,7 +186,7 @@ def get_train_args(
# postprocess model_args
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len

View File

@@ -1,7 +1,7 @@
import os
import json
import gradio as gr
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
@@ -27,7 +27,6 @@ CKPT_NAMES = [
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]
def get_save_dir(*args) -> os.PathLike:
@@ -38,7 +37,7 @@ def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> CONFIG_CLASS:
def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
@@ -46,20 +45,20 @@ def load_config() -> CONFIG_CLASS:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
def save_config(
config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
) -> None:
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
config["lang"] = lang or config["lang"]
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
if model_name:
config["last_model"] = model_name
config["path_dict"][model_name] = model_path
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
json.dump(user_config, f, indent=2, ensure_ascii=False)
def get_model_path(config: Dict[str, Any], model_name: str) -> str:
return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
def get_model_path(model_name: str) -> str:
user_config = load_config()
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
def get_module(model_name: str) -> str:

View File

@@ -17,10 +17,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
elem_dict.update(dict(
info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
))
elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
config = gr.State(value=load_config())
with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
@@ -39,17 +38,17 @@ def create_top() -> Dict[str, "Component"]:
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
).then(
get_model_path, [config, model_name], [model_path], queue=False
get_model_path, [model_name], [model_path], queue=False
).then(
get_template, [model_name], [template], queue=False
) # do not save config since the below line will save
model_path.change(save_config, inputs=[config, lang, model_name, model_path])
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
).then(
can_quantize, [finetuning_type], [quantization_bit]
can_quantize, [finetuning_type], [quantization_bit], queue=False
)
refresh_btn.click(
@@ -57,7 +56,6 @@ def create_top() -> Dict[str, "Component"]:
)
return dict(
config=config,
lang=lang,
model_name=model_name,
model_path=model_path,

View File

@@ -3,7 +3,7 @@ from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional
from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
@@ -21,8 +21,9 @@ class Engine:
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
lang = config.get("lang", None) or "en"
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
user_config = load_config()
lang = user_config.get("lang", None) or "en"
resume_dict = {
"top.lang": {"value": lang},
@@ -33,9 +34,9 @@
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
if config.get("last_model", None):
resume_dict["top.model_name"] = {"value": config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
if user_config.get("last_model", None):
resume_dict["top.model_name"] = {"value": user_config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
yield self._form_dict(resume_dict)

View File

@@ -9,12 +9,12 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
from llmtuner.webui.common import load_config, save_config
from llmtuner.webui.common import save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
def create_ui() -> gr.Blocks:
@@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
config = engine.manager.get_elem("top.config")
model_name = engine.manager.get_elem("top.model_name")
model_path = engine.manager.get_elem("top.model_path")
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
@@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
demo.load(engine.resume, [config], engine.manager.list_elems())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang, model_name, model_path]
)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)
return demo
@@ -54,21 +47,15 @@ def create_web_demo() -> gr.Blocks:
engine = Engine(pure_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
config = gr.State(value=load_config())
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.all_elems["top"] = dict(config=config, lang=lang)
engine.manager.all_elems["top"] = dict(lang=lang)
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
demo.load(engine.resume, [config], engine.manager.list_elems())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang]
)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)
return demo

View File

@@ -18,7 +18,6 @@ class Manager:
def get_base_elems(self):
return {
self.all_elems["top"]["config"],
self.all_elems["top"]["lang"],
self.all_elems["top"]["model_name"],
self.all_elems["top"]["model_path"],

View File

@@ -14,7 +14,7 @@ from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_module, get_save_dir
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -74,6 +74,7 @@ class Runner:
def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
@@ -89,7 +90,7 @@
model_name_or_path=get("top.model_path"),
do_train=True,
overwrite_cache=False,
cache_dir=get("top.config").get("cache_dir", None),
cache_dir=user_config.get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
@@ -142,6 +143,7 @@
def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
@@ -160,7 +162,7 @@
do_eval=True,
overwrite_cache=False,
predict_with_generate=True,
cache_dir=get("top.config").get("cache_dir", None),
cache_dir=user_config.get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,