fix config, #1191
This commit is contained in:
parent
0d63584c03
commit
a6a04be2e6
|
@ -12,7 +12,7 @@ fire
|
|||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
gradio>=3.36.0
|
||||
gradio==3.38.0
|
||||
uvicorn
|
||||
pydantic==1.10.11
|
||||
fastapi==0.95.1
|
||||
|
|
|
@ -186,7 +186,7 @@ def get_train_args(
|
|||
|
||||
# postprocess model_args
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
import gradio as gr
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional
|
||||
from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
|
@ -27,7 +27,6 @@ CKPT_NAMES = [
|
|||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
]
|
||||
CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
|
@ -38,7 +37,7 @@ def get_config_path() -> os.PathLike:
|
|||
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
||||
|
||||
|
||||
def load_config() -> CONFIG_CLASS:
|
||||
def load_config() -> Dict[str, Any]:
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
@ -46,20 +45,20 @@ def load_config() -> CONFIG_CLASS:
|
|||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
|
||||
def save_config(
|
||||
config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
|
||||
) -> None:
|
||||
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
config["lang"] = lang or config["lang"]
|
||||
user_config = load_config()
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
if model_name:
|
||||
config["last_model"] = model_name
|
||||
config["path_dict"][model_name] = model_path
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2, ensure_ascii=False)
|
||||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_model_path(config: Dict[str, Any], model_name: str) -> str:
|
||||
return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
|
||||
def get_model_path(model_name: str) -> str:
|
||||
user_config = load_config()
|
||||
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
|
||||
|
||||
|
||||
def get_module(model_name: str) -> str:
|
||||
|
|
|
@ -17,10 +17,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
elem_dict.update(dict(
|
||||
info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
|
||||
))
|
||||
elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
|
|||
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
|
||||
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -12,7 +12,6 @@ if TYPE_CHECKING:
|
|||
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
config = gr.State(value=load_config())
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
|
||||
|
@ -39,17 +38,17 @@ def create_top() -> Dict[str, "Component"]:
|
|||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
get_model_path, [config, model_name], [model_path], queue=False
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(
|
||||
get_template, [model_name], [template], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, inputs=[config, lang, model_name, model_path])
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||
)
|
||||
|
||||
refresh_btn.click(
|
||||
|
@ -57,7 +56,6 @@ def create_top() -> Dict[str, "Component"]:
|
|||
)
|
||||
|
||||
return dict(
|
||||
config=config,
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
|
|
|
@ -3,7 +3,7 @@ from gradio.components import Component # cannot use TYPE_CHECKING here
|
|||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from llmtuner.webui.chatter import WebChatModel
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
||||
from llmtuner.webui.locales import LOCALES
|
||||
from llmtuner.webui.manager import Manager
|
||||
from llmtuner.webui.runner import Runner
|
||||
|
@ -21,8 +21,9 @@ class Engine:
|
|||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
lang = config.get("lang", None) or "en"
|
||||
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
user_config = load_config()
|
||||
lang = user_config.get("lang", None) or "en"
|
||||
|
||||
resume_dict = {
|
||||
"top.lang": {"value": lang},
|
||||
|
@ -33,9 +34,9 @@ class Engine:
|
|||
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
|
||||
if config.get("last_model", None):
|
||||
resume_dict["top.model_name"] = {"value": config["last_model"]}
|
||||
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
|
||||
if user_config.get("last_model", None):
|
||||
resume_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
||||
yield self._form_dict(resume_dict)
|
||||
|
||||
|
|
|
@ -9,12 +9,12 @@ from llmtuner.webui.components import (
|
|||
create_export_tab,
|
||||
create_chat_box
|
||||
)
|
||||
from llmtuner.webui.common import load_config, save_config
|
||||
from llmtuner.webui.common import save_config
|
||||
from llmtuner.webui.css import CSS
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
|
||||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
|
@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks:
|
|||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
engine.manager.all_elems["top"] = create_top()
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
|
||||
config = engine.manager.get_elem("top.config")
|
||||
model_name = engine.manager.get_elem("top.model_name")
|
||||
model_path = engine.manager.get_elem("top.model_path")
|
||||
|
||||
with gr.Tab("Train"):
|
||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||
|
@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks:
|
|||
with gr.Tab("Export"):
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
|
||||
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||
|
||||
lang.change(
|
||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||
).then(
|
||||
save_config, inputs=[config, lang, model_name, model_path]
|
||||
)
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
@ -54,21 +47,15 @@ def create_web_demo() -> gr.Blocks:
|
|||
engine = Engine(pure_chat=True)
|
||||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
config = gr.State(value=load_config())
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
|
||||
engine.manager.all_elems["top"] = dict(config=config, lang=lang)
|
||||
engine.manager.all_elems["top"] = dict(lang=lang)
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
|
||||
|
||||
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||
|
||||
lang.change(
|
||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||
).then(
|
||||
save_config, inputs=[config, lang]
|
||||
)
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ class Manager:
|
|||
|
||||
def get_base_elems(self):
|
||||
return {
|
||||
self.all_elems["top"]["config"],
|
||||
self.all_elems["top"]["lang"],
|
||||
self.all_elems["top"]["model_name"],
|
||||
self.all_elems["top"]["model_path"],
|
||||
|
|
|
@ -14,7 +14,7 @@ from llmtuner.extras.constants import TRAINING_STAGES
|
|||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_module, get_save_dir
|
||||
from llmtuner.webui.common import get_module, get_save_dir, load_config
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
|
@ -74,6 +74,7 @@ class Runner:
|
|||
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
|
||||
get = lambda name: data[self.manager.get_elem(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
|
@ -89,7 +90,7 @@ class Runner:
|
|||
model_name_or_path=get("top.model_path"),
|
||||
do_train=True,
|
||||
overwrite_cache=False,
|
||||
cache_dir=get("top.config").get("cache_dir", None),
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
|
@ -142,6 +143,7 @@ class Runner:
|
|||
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
|
||||
get = lambda name: data[self.manager.get_elem(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
|
@ -160,7 +162,7 @@ class Runner:
|
|||
do_eval=True,
|
||||
overwrite_cache=False,
|
||||
predict_with_generate=True,
|
||||
cache_dir=get("top.config").get("cache_dir", None),
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
|
|
Loading…
Reference in New Issue