implement webui resuming training
This commit is contained in:
parent
fde05cacfc
commit
accde3cd39
|
@ -134,6 +134,11 @@ class LogCallback(TrainerCallback):
|
|||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time
|
||||
)
|
||||
if self.runner is not None:
|
||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:.6f}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"], logs["learning_rate"], logs["epoch"]
|
||||
))
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
|
|
@ -19,8 +19,8 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
|
||||
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
elem_dict.update(dict(
|
||||
|
|
|
@ -8,10 +8,9 @@ from llmtuner.webui.utils import can_quantize
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_top(engine: "Engine") -> Dict[str, "Component"]:
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
config = gr.State(value=load_config())
|
||||
|
||||
|
@ -38,11 +37,11 @@ def create_top(engine: "Engine") -> Dict[str, "Component"]:
|
|||
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
get_model_path, [config, model_name], [model_path]
|
||||
get_model_path, [config, model_name], [model_path], queue=False
|
||||
).then(
|
||||
get_template, [model_name], [template]
|
||||
get_template, [model_name], [template], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, inputs=[config, lang, model_name, model_path])
|
||||
|
|
|
@ -24,9 +24,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
elem_dict.update(dict(
|
||||
|
@ -128,6 +128,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
|
@ -139,15 +140,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
input_elems.add(output_dir)
|
||||
output_elems = [output_box, process_bar]
|
||||
elem_dict.update(dict(
|
||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
|
||||
output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer
|
||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
||||
resume_btn=resume_btn, output_box=output_box, loss_viewer=loss_viewer, process_bar=process_bar
|
||||
))
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
|
||||
process_bar.change(
|
||||
output_box.change(
|
||||
gen_plot,
|
||||
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
|
||||
loss_viewer,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from llmtuner.webui.chatter import WebChatModel
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
|
||||
|
@ -18,6 +18,9 @@ class Engine:
|
|||
self.runner: "Runner" = Runner(self.manager)
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
lang = config.get("lang", None) or "en"
|
||||
|
||||
|
@ -34,12 +37,14 @@ class Engine:
|
|||
resume_dict["top.model_name"] = {"value": config["last_model"]}
|
||||
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
|
||||
|
||||
yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
yield self._form_dict(resume_dict)
|
||||
|
||||
if self.runner.alive:
|
||||
pass # TODO: restore training
|
||||
if self.runner.alive: # TODO: resume eval
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
|
||||
resume_dict = {"train.resume_btn": {"value": True}}
|
||||
else:
|
||||
resume_dict = {"train.output_dir": {"value": get_time()}} # TODO: xxx
|
||||
resume_dict = {"train.output_dir": {"value": get_time()}}
|
||||
yield self._form_dict(resume_dict)
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
return {
|
||||
|
|
|
@ -21,7 +21,7 @@ def create_ui() -> gr.Blocks:
|
|||
engine = Engine(pure_chat=False)
|
||||
|
||||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
engine.manager.all_elems["top"] = create_top(engine)
|
||||
engine.manager.all_elems["top"] = create_top()
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
|
||||
config = engine.manager.get_elem("top.config")
|
||||
model_name = engine.manager.get_elem("top.model_name")
|
||||
|
|
|
@ -27,6 +27,9 @@ class Runner:
|
|||
def __init__(self, manager: "Manager") -> None:
|
||||
self.manager = manager
|
||||
self.thread: "Thread" = None
|
||||
self.data: Dict["Component", Any] = None
|
||||
self.do_train = True
|
||||
self.monitor_inputs: Dict[str, str] = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.logger_handler = LoggerHandler()
|
||||
|
@ -199,14 +202,15 @@ class Runner:
|
|||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self.prepare(data, self._parse_train_args)
|
||||
yield from self.prepare(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self.prepare(data, self._parse_eval_args)
|
||||
yield from self.prepare(data, do_train=False)
|
||||
|
||||
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if is_training else self._parse_eval_args
|
||||
def prepare(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
||||
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
|
||||
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
|
||||
error = self._initialize(lang, model_name, model_path, dataset)
|
||||
if error:
|
||||
yield error, gr.update(visible=False)
|
||||
|
@ -215,9 +219,10 @@ class Runner:
|
|||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor(lang, output_dir, is_training)
|
||||
yield from self.monitor()
|
||||
|
||||
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
|
@ -225,7 +230,7 @@ class Runner:
|
|||
else:
|
||||
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||
|
||||
if is_training:
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue