implement webui resuming training

hiyouga 2023-10-15 04:52:19 +08:00
parent fde05cacfc
commit accde3cd39
7 changed files with 42 additions and 26 deletions

View File

@ -134,6 +134,11 @@ class LogCallback(TrainerCallback):
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time
)
+if self.runner is not None:
+    logger.info("{{'loss': {:.4f}, 'learning_rate': {:.6f}, 'epoch': {:.2f}}}".format(
+        logs["loss"], logs["learning_rate"], logs["epoch"]
+    ))
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
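Appending each record to trainer_log.jsonl is what makes resuming in the web UI possible: the file is a durable, line-per-record history that a freshly loaded page can replay to restore the progress display. A minimal sketch of reading it back, assuming only the schema visible above (the helper name and the example path are illustrative, not part of this commit):

```python
import json
import os
from typing import Any, Dict, List

def read_trainer_log(output_dir: str) -> List[Dict[str, Any]]:
    """Replay trainer_log.jsonl (one JSON object per line) into a list of records."""
    log_path = os.path.join(output_dir, "trainer_log.jsonl")
    if not os.path.isfile(log_path):
        return []
    records = []
    with open(log_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate a trailing blank line from an interrupted write
                records.append(json.loads(line))
    return records

# Example: recover the latest logged state after a page reload (path is illustrative).
records = read_trainer_log("saves/test-run")
if records:
    last = records[-1]
    print("last logged loss={} at epoch={}".format(last.get("loss"), last.get("epoch")))
```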

View File

@ -19,8 +19,8 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
-dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(
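The only change in this tab is passing queue=False to the .change listeners. In Gradio, an event registered with queue=False is not held in the shared request queue, so lightweight UI callbacks such as refreshing a dropdown keep responding even while a queued long-running job (training) occupies the worker. A standalone sketch of the pattern; the functions and choices below are stand-ins, not the repository's list_dataset/can_preview:

```python
import gradio as gr

def list_items(folder: str):
    # Stand-in for list_dataset: return dropdown choices for the chosen folder.
    fake_index = {"data": ["alpaca_en", "self_cognition"], "extra": ["oaast_sft"]}
    return gr.update(choices=fake_index.get(folder, []), value=[])

def can_preview(folder: str, selected):
    # Stand-in for can_preview: enable the button only when something is selected.
    return gr.update(interactive=bool(selected))

with gr.Blocks() as demo:
    folder = gr.Dropdown(choices=["data", "extra"], value="data")
    items = gr.Dropdown(multiselect=True)
    preview_btn = gr.Button("Preview", interactive=False)

    # queue=False: these callbacks are cheap, so they skip the request queue and
    # respond immediately even while a queued training job is running.
    folder.change(list_items, [folder], [items], queue=False)
    items.change(can_preview, [folder, items], [preview_btn], queue=False)

demo.queue().launch()
```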

View File

@ -8,10 +8,9 @@ from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
-from llmtuner.webui.engine import Engine
-def create_top(engine: "Engine") -> Dict[str, "Component"]:
+def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
config = gr.State(value=load_config())
@ -38,11 +37,11 @@ def create_top(engine: "Engine") -> Dict[str, "Component"]:
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
model_name.change(
-list_checkpoint, [model_name, finetuning_type], [checkpoints]
+list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
).then(
-get_model_path, [config, model_name], [model_path]
+get_model_path, [config, model_name], [model_path], queue=False
).then(
-get_template, [model_name], [template]
+get_template, [model_name], [template], queue=False
) # do not save config since the below line will save
model_path.change(save_config, inputs=[config, lang, model_name, model_path])
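The same treatment is applied to the chained events in the top bar: each .then() step runs only after the previous callback finishes, and queue=False keeps the whole chain snappy during training. A compact illustration of the chaining pattern, with placeholder callbacks rather than the project's helpers:

```python
import gradio as gr

def list_checkpoints(model_name: str):
    # Placeholder: look up checkpoints for the selected model.
    return gr.update(choices=[model_name + "-ckpt-1", model_name + "-ckpt-2"], value=[])

def guess_path(model_name: str):
    # Placeholder: map a model name to a default hub path.
    return gr.update(value="org/" + model_name)

with gr.Blocks() as demo:
    model_name = gr.Dropdown(choices=["llama2-7b", "baichuan2-7b"])
    checkpoints = gr.Dropdown(multiselect=True)
    model_path = gr.Textbox()

    # Each .then() fires only after the previous step completes; queue=False keeps
    # the chain responsive even while a queued training run occupies the worker.
    model_name.change(
        list_checkpoints, [model_name], [checkpoints], queue=False
    ).then(
        guess_path, [model_name], [model_path], queue=False
    )

demo.queue().launch()
```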

View File

@ -24,9 +24,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
-training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
-dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
-dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
@ -128,6 +128,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
output_dir = gr.Textbox()
with gr.Row():
+resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
@ -139,15 +140,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.add(output_dir)
output_elems = [output_box, process_bar]
elem_dict.update(dict(
-cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
-output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer
+cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
+resume_btn=resume_btn, output_box=output_box, loss_viewer=loss_viewer, process_bar=process_bar
))
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
+resume_btn.change(engine.runner.monitor, outputs=output_elems)
-process_bar.change(
+output_box.change(
gen_plot,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
loss_viewer,
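The train tab gains a hidden resume_btn checkbox. It has no visible UI, but flipping its value from the resume logic fires a .change event, which re-attaches engine.runner.monitor to output_box and process_bar; the loss plot is now redrawn on output_box changes instead of progress-bar changes. A sketch of the hidden-component-as-trigger pattern, with a toy streaming function standing in for the runner:

```python
import time
import gradio as gr

def fake_monitor():
    # Toy stand-in for Runner.monitor: stream log text and a progress-bar update.
    for step in range(1, 6):
        time.sleep(1)
        yield "step {}/5".format(step), gr.update(visible=True, value=step * 20)

with gr.Blocks() as demo:
    # Hidden checkbox: never shown to the user, but a value change still fires .change().
    resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
    output_box = gr.Textbox(lines=5)
    process_bar = gr.Slider(visible=False, interactive=False)

    # When a reloaded page detects a live job, setting resume_btn to True from the
    # load handler re-attaches the streaming monitor to the same output components.
    resume_btn.change(fake_monitor, outputs=[output_box, process_bar])

demo.queue().launch()
```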

View File

@ -1,6 +1,6 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, Optional
from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
@ -18,6 +18,9 @@ class Engine:
self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
+def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
+    return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
lang = config.get("lang", None) or "en"
@ -34,12 +37,14 @@ class Engine:
resume_dict["top.model_name"] = {"value": config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
-yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
+yield self._form_dict(resume_dict)
-if self.runner.alive:
-    pass # TODO: restore training
+if self.runner.alive: # TODO: resume eval
+    yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
+    resume_dict = {"train.resume_btn": {"value": True}}
+else:
-resume_dict = {"train.output_dir": {"value": get_time()}} # TODO: xxx
+    resume_dict = {"train.output_dir": {"value": get_time()}}
+yield self._form_dict(resume_dict)
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return {
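Engine.resume is a generator that yields dicts mapping components to gr.update(...) calls: it first restores the saved top-bar settings, then either replays the live runner's component values and sets train.resume_btn to True (which re-starts monitoring through the event above) or proposes a fresh output_dir. The natural place to attach such a generator is the Blocks load event. A self-contained sketch of that yield-a-dict-of-updates shape, with illustrative component names and values rather than the project's:

```python
import gradio as gr

def restore_state():
    # Load-time generator: each yield maps components to gr.update(...) values,
    # the same shape Engine.resume produces via _form_dict.
    yield {lang: gr.update(value="en"), output_dir: gr.update(value="train_resumed_demo")}
    # If a previous job were still alive, a later yield could flip the hidden
    # checkbox to True to kick monitoring back off (mirrors train.resume_btn).
    yield {resume_flag: gr.update(value=True)}

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"])
    output_dir = gr.Textbox()
    resume_flag = gr.Checkbox(visible=False, value=False)

    demo.load(restore_state, outputs=[lang, output_dir, resume_flag])

demo.queue().launch()
```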

View File

@ -21,7 +21,7 @@ def create_ui() -> gr.Blocks:
engine = Engine(pure_chat=False)
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top(engine)
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
config = engine.manager.get_elem("top.config")
model_name = engine.manager.get_elem("top.model_name")

View File

@ -27,6 +27,9 @@ class Runner:
def __init__(self, manager: "Manager") -> None:
self.manager = manager
self.thread: "Thread" = None
self.data: Dict["Component", Any] = None
self.do_train = True
self.monitor_inputs: Dict[str, str] = None
self.aborted = False
self.running = False
self.logger_handler = LoggerHandler()
@ -199,14 +202,15 @@ class Runner:
yield gen_cmd(args), gr.update(visible=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-yield from self.prepare(data, self._parse_train_args)
+yield from self.prepare(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-yield from self.prepare(data, self._parse_eval_args)
+yield from self.prepare(data, do_train=False)
-def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-    parse_func = self._parse_train_args if is_training else self._parse_eval_args
+def prepare(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    parse_func = self._parse_train_args if do_train else self._parse_eval_args
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
+self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
error = self._initialize(lang, model_name, model_path, dataset)
if error:
yield error, gr.update(visible=False)
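prepare() now stashes data, do_train and the lang/output_dir pair on the instance before launching the thread. That is what lets monitor() be re-attached later with no inputs at all: when the hidden checkbox flips after a page reload, Gradio calls it argument-free and the method recovers everything from self. A small sketch of the idea, with hypothetical names rather than the Runner class itself:

```python
from threading import Thread
from typing import Dict, Generator, Optional

class JobState:
    """Hypothetical sketch: keep everything a no-argument monitor needs on the instance."""

    def __init__(self) -> None:
        self.thread: Optional[Thread] = None
        self.do_train: bool = True
        self.monitor_inputs: Dict[str, str] = {}

    def prepare(self, do_train: bool, lang: str, output_dir: str) -> None:
        # Stash the values before starting the worker, so a reloaded page can
        # call monitor() without being handed any of them again.
        self.do_train = do_train
        self.monitor_inputs = dict(lang=lang, output_dir=output_dir)

    def monitor(self) -> Generator[str, None, None]:
        lang = self.monitor_inputs["lang"]
        output_dir = self.monitor_inputs["output_dir"]
        yield "[{}] watching {} (do_train={})".format(lang, output_dir, self.do_train)

# Usage: prepare once, monitor later with no arguments.
state = JobState()
state.prepare(do_train=True, lang="en", output_dir="saves/demo")
print(next(state.monitor()))
```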
@ -215,9 +219,10 @@ class Runner:
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
-yield from self.monitor(lang, output_dir, is_training)
+yield from self.monitor()
-def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
while self.thread.is_alive():
time.sleep(2)
if self.aborted:
@ -225,7 +230,7 @@ class Runner:
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
-if is_training:
+if self.do_train:
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
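With its parameters gone, monitor() is a plain polling generator: while the worker thread is alive it sleeps, then yields the accumulated log plus a progress update; once the thread exits it checks the output directory for a finished-training artifact to decide between success and failure. A self-contained sketch of that poll-and-yield loop; the worker, log file and marker file below are toys, not run_exp or TRAINING_ARGS_NAME:

```python
import os
import time
from threading import Thread
from typing import Generator, Tuple

def poll_worker(thread: Thread, log_path: str, done_marker: str) -> Generator[Tuple[str, str], None, None]:
    # Poll a background job every 2 seconds, yielding (log_text, status) pairs the
    # way the web runner streams updates to output_box / process_bar.
    log_text = ""
    while thread.is_alive():
        time.sleep(2)
        if os.path.isfile(log_path):
            with open(log_path, "r", encoding="utf-8") as f:
                log_text = f.read()
        yield log_text, "running..."

    # After the thread exits, an on-disk artifact distinguishes success from failure,
    # mirroring the TRAINING_ARGS_NAME check above (file names here are placeholders).
    if os.path.exists(done_marker):
        yield log_text, "finished"
    else:
        yield log_text, "errored"

# Usage: watch a dummy 5-second job.
job = Thread(target=time.sleep, args=(5,))
job.start()
for text, status in poll_worker(job, "trainer_log.jsonl", "training_args.bin"):
    print(status)
```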