diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chatter.py similarity index 53% rename from src/llmtuner/webui/chat.py rename to src/llmtuner/webui/chatter.py index 1db329c2..712f2c75 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chatter.py @@ -1,69 +1,73 @@ -from typing import Any, Dict, Generator, List, Optional, Tuple +from gradio.components import Component # cannot use TYPE_CHECKING here +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from llmtuner.chat.stream_chat import ChatModel from llmtuner.extras.misc import torch_gc from llmtuner.hparams import GeneratingArguments -from llmtuner.webui.common import get_model_path, get_save_dir +from llmtuner.webui.common import get_save_dir from llmtuner.webui.locales import ALERTS +if TYPE_CHECKING: + from llmtuner.webui.manager import Manager + class WebChatModel(ChatModel): - def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None: + def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None: + self.manager = manager self.model = None self.tokenizer = None self.generating_args = GeneratingArguments() if not lazy_init: - super().__init__(args) + super().__init__() - def load_model( - self, - lang: str, - model_name: str, - checkpoints: List[str], - finetuning_type: str, - quantization_bit: str, - template: str, - system_prompt: str, - flash_attn: bool, - shift_attn: bool, - rope_scaling: str - ) -> Generator[str, None, None]: - if self.model is not None: + @property + def loaded(self) -> bool: + return self.model is not None + + def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: + get = lambda name: data[self.manager.get_elem(name)] + lang = get("top.lang") + + if self.loaded: yield ALERTS["err_exists"][lang] return - if not model_name: + if not get("top.model_name"): yield ALERTS["err_no_model"][lang] return - model_name_or_path = get_model_path(model_name) - if not model_name_or_path: + if not get("top.model_path"): yield ALERTS["err_no_path"][lang] return - if checkpoints: - checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]) + if get("top.checkpoints"): + checkpoint_dir = ",".join([ + get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") + ]) else: checkpoint_dir = None yield ALERTS["info_loading"][lang] args = dict( - model_name_or_path=model_name_or_path, + model_name_or_path=get("top.model_path"), checkpoint_dir=checkpoint_dir, - finetuning_type=finetuning_type, - quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None, - template=template, - system_prompt=system_prompt, - flash_attn=flash_attn, - shift_attn=shift_attn, - rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None + finetuning_type=get("top.finetuning_type"), + quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + template=get("top.template"), + system_prompt=get("top.system_prompt"), + flash_attn=get("top.flash_attn"), + shift_attn=get("top.shift_attn"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None ) super().__init__(args) yield ALERTS["info_loaded"][lang] - def unload_model(self, lang: str) -> Generator[str, None, None]: + def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: + get = lambda name: data[self.manager.get_elem(name)] + lang = get("top.lang") + yield ALERTS["info_unloading"][lang] self.model = None self.tokenizer = None diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index fd0044b6..a43cee9e 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -1,7 +1,7 @@ import os import json import gradio as gr -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from transformers.utils import ( WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -11,7 +11,7 @@ from transformers.utils import ( ADAPTER_SAFE_WEIGHTS_NAME ) -from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES +from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES DEFAULT_CACHE_DIR = "cache" @@ -27,6 +27,7 @@ CKPT_NAMES = [ ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME ] +CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]] def get_save_dir(*args) -> os.PathLike: @@ -37,7 +38,7 @@ def get_config_path() -> os.PathLike: return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) -def load_config() -> Dict[str, Any]: +def load_config() -> CONFIG_CLASS: try: with open(get_config_path(), "r", encoding="utf-8") as f: return json.load(f) @@ -45,20 +46,24 @@ def load_config() -> Dict[str, Any]: return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} -def save_config(lang: str, model_name: str, model_path: str) -> None: +def save_config( + config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None +) -> None: os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) - user_config = load_config() - user_config["lang"] = lang or user_config["lang"] + config["lang"] = lang or config["lang"] if model_name: - user_config["last_model"] = model_name - user_config["path_dict"][model_name] = model_path + config["last_model"] = model_name + config["path_dict"][model_name] = model_path with open(get_config_path(), "w", encoding="utf-8") as f: - json.dump(user_config, f, indent=2, ensure_ascii=False) + json.dump(config, f, indent=2, ensure_ascii=False) -def get_model_path(model_name: str) -> str: - user_config = load_config() - return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, "")) +def get_model_path(config: Dict[str, Any], model_name: str) -> str: + return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "") + + +def get_module(model_name: str) -> str: + return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj") def get_template(model_name: str) -> str: diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 9de397eb..57f14d4a 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -4,13 +4,15 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple if TYPE_CHECKING: from gradio.blocks import Block from gradio.components import Component - from llmtuner.webui.chat import WebChatModel + from llmtuner.webui.engine import Engine def create_chat_box( - chat_model: "WebChatModel", + engine: "Engine", visible: Optional[bool] = False ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: + elem_dict = dict() + with gr.Box(visible=visible) as chat_box: chatbot = gr.Chatbot() @@ -22,14 +24,20 @@ def create_chat_box( with gr.Column(scale=1): clear_btn = gr.Button() - max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1) - top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01) - temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01) + gen_kwargs = engine.chatter.generating_args + max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1) + top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) + temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) + + elem_dict.update(dict( + system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn, + max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + )) history = gr.State([]) submit_btn.click( - chat_model.predict, + engine.chatter.predict, [chatbot, query, history, system, max_new_tokens, top_p, temperature], [chatbot, history], show_progress=True @@ -39,12 +47,4 @@ def create_chat_box( clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) - return chat_box, chatbot, history, dict( - system=system, - query=query, - submit_btn=submit_btn, - clear_btn=clear_btn, - max_new_tokens=max_new_tokens, - top_p=top_p, - temperature=temperature - ) + return chat_box, chatbot, history, elem_dict diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index 9a4d5a8d..f7bffae9 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -7,19 +7,28 @@ from llmtuner.webui.utils import can_preview, get_preview if TYPE_CHECKING: from gradio.components import Component - from llmtuner.webui.runner import Runner + from llmtuner.webui.engine import Engine -def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: +def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + with gr.Row(): dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) data_preview_btn = gr.Button(interactive=False, scale=1) - preview_box, preview_count, preview_samples, close_btn = create_preview_box() - dataset_dir.change(list_dataset, [dataset_dir], [dataset]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) + + input_elems.update({dataset_dir, dataset}) + elem_dict.update(dict( + dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn + )) + + preview_box, preview_count, preview_samples, close_btn = create_preview_box() + data_preview_btn.click( get_preview, [dataset_dir, dataset], @@ -27,17 +36,31 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict queue=False ) + elem_dict.update(dict( + preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn + )) + with gr.Row(): cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) max_samples = gr.Textbox(value="100000") batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) predict = gr.Checkbox(value=True) + input_elems.update({cutoff_len, max_samples, batch_size, predict}) + elem_dict.update(dict( + cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict + )) + with gr.Row(): max_new_tokens = gr.Slider(10, 2048, value=128, step=1) top_p = gr.Slider(0.01, 1, value=0.7, step=0.01) temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01) + input_elems.update({max_new_tokens, top_p, temperature}) + elem_dict.update(dict( + max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + )) + with gr.Row(): cmd_preview_btn = gr.Button() start_btn = gr.Button() @@ -49,53 +72,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict with gr.Box(): output_box = gr.Markdown() - input_components = [ - top_elems["lang"], - top_elems["model_name"], - top_elems["checkpoints"], - top_elems["finetuning_type"], - top_elems["quantization_bit"], - top_elems["template"], - top_elems["system_prompt"], - top_elems["flash_attn"], - top_elems["shift_attn"], - top_elems["rope_scaling"], - dataset_dir, - dataset, - cutoff_len, - max_samples, - batch_size, - predict, - max_new_tokens, - top_p, - temperature - ] + output_elems = [output_box, process_bar] + elem_dict.update(dict( + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box + )) - output_components = [ - output_box, - process_bar - ] + cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems) + start_btn.click(engine.runner.run_eval, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort, queue=False) - cmd_preview_btn.click(runner.preview_eval, input_components, output_components) - start_btn.click(runner.run_eval, input_components, output_components) - stop_btn.click(runner.set_abort, queue=False) - - return dict( - dataset_dir=dataset_dir, - dataset=dataset, - data_preview_btn=data_preview_btn, - preview_count=preview_count, - preview_samples=preview_samples, - close_btn=close_btn, - cutoff_len=cutoff_len, - max_samples=max_samples, - batch_size=batch_size, - predict=predict, - max_new_tokens=max_new_tokens, - top_p=top_p, - temperature=temperature, - cmd_preview_btn=cmd_preview_btn, - start_btn=start_btn, - stop_btn=stop_btn, - output_box=output_box - ) + return elem_dict diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index 6d11c003..fa4cc770 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -5,9 +5,12 @@ from llmtuner.webui.utils import save_model if TYPE_CHECKING: from gradio.components import Component + from llmtuner.webui.engine import Engine -def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]: +def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: + elem_dict = dict() + with gr.Row(): save_dir = gr.Textbox() max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) @@ -18,20 +21,23 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component export_btn.click( save_model, [ - top_elems["lang"], - top_elems["model_name"], - top_elems["checkpoints"], - top_elems["finetuning_type"], - top_elems["template"], + engine.manager.get_elem("top.lang"), + engine.manager.get_elem("top.model_name"), + engine.manager.get_elem("top.model_path"), + engine.manager.get_elem("top.checkpoints"), + engine.manager.get_elem("top.finetuning_type"), + engine.manager.get_elem("top.template"), max_shard_size, save_dir ], [info_box] ) - return dict( + elem_dict.update(dict( save_dir=save_dir, max_shard_size=max_shard_size, export_btn=export_btn, info_box=info_box - ) + )) + + return elem_dict diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index e4488a4f..ec935ee6 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -1,53 +1,42 @@ import gradio as gr from typing import TYPE_CHECKING, Dict -from llmtuner.webui.chat import WebChatModel from llmtuner.webui.components.chatbot import create_chat_box if TYPE_CHECKING: from gradio.components import Component + from llmtuner.webui.engine import Engine -def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]: +def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + with gr.Row(): load_btn = gr.Button() unload_btn = gr.Button() info_box = gr.Textbox(show_label=False, interactive=False) - chat_model = WebChatModel(lazy_init=True) - chat_box, chatbot, history, chat_elems = create_chat_box(chat_model) + elem_dict.update(dict( + info_box=info_box, load_btn=load_btn, unload_btn=unload_btn + )) + + chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False) + elem_dict.update(dict(chat_box=chat_box, **chat_elems)) load_btn.click( - chat_model.load_model, - [ - top_elems["lang"], - top_elems["model_name"], - top_elems["checkpoints"], - top_elems["finetuning_type"], - top_elems["quantization_bit"], - top_elems["template"], - top_elems["system_prompt"], - top_elems["flash_attn"], - top_elems["shift_attn"], - top_elems["rope_scaling"] - ], - [info_box] + engine.chatter.load_model, input_elems, [info_box] ).then( - lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] + lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] ) unload_btn.click( - chat_model.unload_model, [top_elems["lang"]], [info_box] + engine.chatter.unload_model, input_elems, [info_box] ).then( lambda: ([], []), outputs=[chatbot, history] ).then( - lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] + lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] ) - return dict( - info_box=info_box, - load_btn=load_btn, - unload_btn=unload_btn, - **chat_elems - ) + return elem_dict diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index 8c4698dd..ec6fb91e 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -3,15 +3,17 @@ from typing import TYPE_CHECKING, Dict from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.template import templates -from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config +from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config from llmtuner.webui.utils import can_quantize if TYPE_CHECKING: from gradio.components import Component + from llmtuner.webui.engine import Engine -def create_top() -> Dict[str, "Component"]: +def create_top(engine: "Engine") -> Dict[str, "Component"]: available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] + config = gr.State(value=load_config()) with gr.Row(): lang = gr.Dropdown(choices=["en", "zh"], scale=1) @@ -35,17 +37,21 @@ def create_top() -> Dict[str, "Component"]: shift_attn = gr.Checkbox(value=False) rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none") - lang.change(save_config, [lang, model_name, model_path]) + lang.change( + engine.change_lang, [lang], engine.manager.list_elems(), queue=False + ).then( + save_config, inputs=[config, lang, model_name, model_path] + ) model_name.change( list_checkpoint, [model_name, finetuning_type], [checkpoints] ).then( - get_model_path, [model_name], [model_path] + get_model_path, [config, model_name], [model_path] ).then( get_template, [model_name], [template] ) # do not save config since the below line will save - model_path.change(save_config, [lang, model_name, model_path]) + model_path.change(save_config, inputs=[config, lang, model_name, model_path]) finetuning_type.change( list_checkpoint, [model_name, finetuning_type], [checkpoints] @@ -58,6 +64,7 @@ def create_top() -> Dict[str, "Component"]: ) return dict( + config=config, lang=lang, model_name=model_name, model_path=model_path, diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 5b74034e..88f3b862 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -9,10 +9,13 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot if TYPE_CHECKING: from gradio.components import Component - from llmtuner.webui.runner import Runner + from llmtuner.webui.engine import Engine -def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: +def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + with gr.Row(): training_stage = gr.Dropdown( choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2 @@ -21,11 +24,17 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic dataset = gr.Dropdown(multiselect=True, scale=4) data_preview_btn = gr.Button(interactive=False, scale=1) - preview_box, preview_count, preview_samples, close_btn = create_preview_box() - training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset]) dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset]) dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) + + input_elems.update({training_stage, dataset_dir, dataset}) + elem_dict.update(dict( + training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn + )) + + preview_box, preview_count, preview_samples, close_btn = create_preview_box() + data_preview_btn.click( get_preview, [dataset_dir, dataset], @@ -33,6 +42,10 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic queue=False ) + elem_dict.update(dict( + preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn + )) + with gr.Row(): cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) learning_rate = gr.Textbox(value="5e-5") @@ -40,6 +53,12 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic max_samples = gr.Textbox(value="100000") compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16") + input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type}) + elem_dict.update(dict( + cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs, + max_samples=max_samples, compute_type=compute_type + )) + with gr.Row(): batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) @@ -49,12 +68,23 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic max_grad_norm = gr.Textbox(value="1.0") val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) + input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size}) + elem_dict.update(dict( + batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, + lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size + )) + with gr.Accordion(label="Advanced config", open=False) as advanced_tab: with gr.Row(): logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) + input_elems.update({logging_steps, save_steps, warmup_steps}) + elem_dict.update(dict( + advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps + )) + with gr.Accordion(label="LoRA config", open=False) as lora_tab: with gr.Row(): lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) @@ -62,6 +92,15 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic lora_target = gr.Textbox(scale=2) resume_lora_training = gr.Checkbox(value=True, scale=1) + input_elems.update({lora_rank, lora_dropout, lora_target, resume_lora_training}) + elem_dict.update(dict( + lora_tab=lora_tab, + lora_rank=lora_rank, + lora_dropout=lora_dropout, + lora_target=lora_target, + resume_lora_training=resume_lora_training, + )) + with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: with gr.Row(): dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) @@ -70,11 +109,14 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic refresh_btn.click( list_checkpoint, - [top_elems["model_name"], top_elems["finetuning_type"]], + [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")], [reward_model], queue=False ) + input_elems.update({dpo_beta, reward_model}) + elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn)) + with gr.Row(): cmd_preview_btn = gr.Button() start_btn = gr.Button() @@ -94,90 +136,22 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic with gr.Column(scale=1): loss_viewer = gr.Plot() - input_components = [ - top_elems["lang"], - top_elems["model_name"], - top_elems["checkpoints"], - top_elems["finetuning_type"], - top_elems["quantization_bit"], - top_elems["template"], - top_elems["system_prompt"], - top_elems["flash_attn"], - top_elems["shift_attn"], - top_elems["rope_scaling"], - training_stage, - dataset_dir, - dataset, - cutoff_len, - learning_rate, - num_train_epochs, - max_samples, - compute_type, - batch_size, - gradient_accumulation_steps, - lr_scheduler_type, - max_grad_norm, - val_size, - logging_steps, - save_steps, - warmup_steps, - lora_rank, - lora_dropout, - lora_target, - resume_lora_training, - dpo_beta, - reward_model, - output_dir - ] + input_elems.add(output_dir) + output_elems = [output_box, process_bar] + elem_dict.update(dict( + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, + output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer + )) - output_components = [ - output_box, - process_bar - ] - - cmd_preview_btn.click(runner.preview_train, input_components, output_components) - start_btn.click(runner.run_train, input_components, output_components) - stop_btn.click(runner.set_abort, queue=False) + cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems) + start_btn.click(engine.runner.run_train, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort, queue=False) process_bar.change( - gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False + gen_plot, + [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir], + loss_viewer, + queue=False ) - return dict( - training_stage=training_stage, - dataset_dir=dataset_dir, - dataset=dataset, - data_preview_btn=data_preview_btn, - preview_count=preview_count, - preview_samples=preview_samples, - close_btn=close_btn, - cutoff_len=cutoff_len, - learning_rate=learning_rate, - num_train_epochs=num_train_epochs, - max_samples=max_samples, - compute_type=compute_type, - batch_size=batch_size, - gradient_accumulation_steps=gradient_accumulation_steps, - lr_scheduler_type=lr_scheduler_type, - max_grad_norm=max_grad_norm, - val_size=val_size, - advanced_tab=advanced_tab, - logging_steps=logging_steps, - save_steps=save_steps, - warmup_steps=warmup_steps, - lora_tab=lora_tab, - lora_rank=lora_rank, - lora_dropout=lora_dropout, - lora_target=lora_target, - resume_lora_training=resume_lora_training, - rlhf_tab=rlhf_tab, - dpo_beta=dpo_beta, - reward_model=reward_model, - refresh_btn=refresh_btn, - cmd_preview_btn=cmd_preview_btn, - start_btn=start_btn, - stop_btn=stop_btn, - output_dir=output_dir, - output_box=output_box, - loss_viewer=loss_viewer - ) + return elem_dict diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py new file mode 100644 index 00000000..90beb5e2 --- /dev/null +++ b/src/llmtuner/webui/engine.py @@ -0,0 +1,46 @@ +import gradio as gr +from gradio.components import Component # cannot use TYPE_CHECKING here +from typing import Any, Dict, Generator, List, Optional, Tuple + +from llmtuner.webui.chatter import WebChatModel +from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS +from llmtuner.webui.locales import LOCALES +from llmtuner.webui.manager import Manager +from llmtuner.webui.runner import Runner +from llmtuner.webui.utils import get_time + + +class Engine: + + def __init__(self, init_chat: Optional[bool] = False) -> None: + self.manager: "Manager" = Manager() + self.runner: "Runner" = Runner(self.manager) + self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat)) + + def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]: + lang = config.get("lang", None) or "en" + + resume_dict = { + "top.config": {"value": config}, + "top.lang": {"value": lang}, + "train.dataset": {"choices": list_dataset()["choices"]}, + "eval.dataset": {"choices": list_dataset()["choices"]}, + "infer.chat_box": {"visible": self.chatter.loaded} + } + + if config.get("last_model", None): + resume_dict["top.model_name"] = {"value": config["last_model"]} + resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])} + + yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()} + + if self.runner.alive: + pass # TODO: restore training + else: + resume_dict = {"train.output_dir": {"value": get_time()}} # TODO: xxx + + def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: + return { + component: gr.update(**LOCALES[name][lang]) + for elems in self.manager.all_elems.values() for name, component in elems.items() + } diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 3b351f63..85f44040 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -9,65 +9,54 @@ from llmtuner.webui.components import ( create_export_tab, create_chat_box ) -from llmtuner.webui.chat import WebChatModel +from llmtuner.webui.common import load_config, save_config from llmtuner.webui.css import CSS -from llmtuner.webui.manager import Manager -from llmtuner.webui.runner import Runner +from llmtuner.webui.engine import Engine require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") def create_ui() -> gr.Blocks: - runner = Runner() + engine = Engine(init_chat=False) with gr.Blocks(title="Web Tuner", css=CSS) as demo: - top_elems = create_top() + engine.manager.all_elems["top"] = create_top(engine) with gr.Tab("Train"): - train_elems = create_train_tab(top_elems, runner) + engine.manager.all_elems["train"] = create_train_tab(engine) with gr.Tab("Evaluate"): - eval_elems = create_eval_tab(top_elems, runner) + engine.manager.all_elems["eval"] = create_eval_tab(engine) with gr.Tab("Chat"): - infer_elems = create_infer_tab(top_elems) + engine.manager.all_elems["infer"] = create_infer_tab(engine) with gr.Tab("Export"): - export_elems = create_export_tab(top_elems) + engine.manager.all_elems["export"] = create_export_tab(engine) - elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems] - manager = Manager(elem_list) - - demo.load( - manager.gen_label, - [top_elems["lang"]], - [elem for elems in elem_list for elem in elems.values()], - ) - - top_elems["lang"].change( - manager.gen_label, - [top_elems["lang"]], - [elem for elems in elem_list for elem in elems.values()], - queue=False - ) + demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems()) return demo def create_web_demo() -> gr.Blocks: - chat_model = WebChatModel(lazy_init=False) + engine = Engine(init_chat=True) with gr.Blocks(title="Web Demo", css=CSS) as demo: - lang = gr.Dropdown(choices=["en", "zh"], value="en") + lang = gr.Dropdown(choices=["en", "zh"]) + config = gr.State(value=load_config()) + lang.change( + engine.change_lang, [lang], engine.manager.list_elems(), queue=False + ).then( + save_config, inputs=[config, lang] + ) - _, _, _, chat_elems = create_chat_box(chat_model, visible=True) + engine.manager.all_elems["top"] = dict(lang=lang) - manager = Manager([{"lang": lang}, chat_elems]) + _, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True) - demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values())) - - lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False) + demo.load(engine.resume, [config], engine.manager.list_elems()) return demo diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 93005e52..3bfa5329 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -1,4 +1,8 @@ LOCALES = { + "config": { + "en": {}, + "zh": {} + }, "lang": { "en": { "label": "Lang" @@ -443,6 +447,10 @@ LOCALES = { "label": "保存预测结果" } }, + "chat_box": { + "en": {}, + "zh": {} + }, "load_btn": { "en": { "value": "Load model" diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py index 0593657f..ebd28463 100644 --- a/src/llmtuner/webui/manager.py +++ b/src/llmtuner/webui/manager.py @@ -1,46 +1,36 @@ -import gradio as gr -from gradio.components import Component -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Dict, List -from llmtuner.webui.common import get_model_path, list_dataset, load_config -from llmtuner.webui.locales import LOCALES -from llmtuner.webui.utils import get_time +if TYPE_CHECKING: + from gradio.components import Component class Manager: - def __init__(self, elem_list: List[Dict[str, Component]]): - self.elem_list = elem_list + def __init__(self) -> None: + self.all_elems: Dict[str, Dict[str, "Component"]] = {} - def gen_refresh(self, lang: str) -> Dict[str, Any]: - refresh_dict = { - "dataset": {"choices": list_dataset()["choices"]}, - "output_dir": {"value": get_time()} + def get_elem(self, name: str) -> "Component": + r""" + Example: top.lang, train.dataset + """ + tab_name, elem_name = name.split(".") + return self.all_elems[tab_name][elem_name] + + def get_base_elems(self): + return { + self.all_elems["top"]["config"], + self.all_elems["top"]["lang"], + self.all_elems["top"]["model_name"], + self.all_elems["top"]["model_path"], + self.all_elems["top"]["checkpoints"], + self.all_elems["top"]["finetuning_type"], + self.all_elems["top"]["quantization_bit"], + self.all_elems["top"]["template"], + self.all_elems["top"]["system_prompt"], + self.all_elems["top"]["flash_attn"], + self.all_elems["top"]["shift_attn"], + self.all_elems["top"]["rope_scaling"] } - user_config = load_config() - if not lang: - if user_config.get("lang", None): - lang = user_config["lang"] - else: - lang = "en" - - refresh_dict["lang"] = {"value": lang} - - if user_config.get("last_model", None): - refresh_dict["model_name"] = {"value": user_config["last_model"]} - refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])} - - return refresh_dict - - def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING - update_dict = {} - refresh_dict = self.gen_refresh(lang) - - for elems in self.elem_list: - for name, component in elems.items(): - update_dict[component] = gr.update( - **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {}) - ) - - return update_dict + def list_elems(self) -> List["Component"]: + return [elem for elems in self.all_elems.values() for elem in elems.values()] diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 2af4d4a7..5c629790 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -1,26 +1,32 @@ import os import time import logging -import threading import gradio as gr -from typing import Any, Dict, Generator, List, Tuple +from threading import Thread +from gradio.components import Component # cannot use TYPE_CHECKING here +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple import transformers from transformers.trainer import TRAINING_ARGS_NAME from llmtuner.extras.callbacks import LogCallback -from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES +from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc import torch_gc from llmtuner.tuner import run_exp -from llmtuner.webui.common import get_model_path, get_save_dir, load_config +from llmtuner.webui.common import get_module, get_save_dir from llmtuner.webui.locales import ALERTS from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar +if TYPE_CHECKING: + from llmtuner.webui.manager import Manager + class Runner: - def __init__(self): + def __init__(self, manager: "Manager") -> None: + self.manager = manager + self.thread: "Thread" = None self.aborted = False self.running = False self.logger_handler = LoggerHandler() @@ -28,20 +34,22 @@ class Runner: logging.root.addHandler(self.logger_handler) transformers.logging.add_handler(self.logger_handler) - def set_abort(self): + @property + def alive(self) -> bool: + return self.thread is not None + + def set_abort(self) -> None: self.aborted = True self.running = False - def _initialize( - self, lang: str, model_name: str, dataset: List[str] - ) -> str: + def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str: if self.running: return ALERTS["err_conflict"][lang] if not model_name: return ALERTS["err_no_model"][lang] - if not get_model_path(model_name): + if not model_path: return ALERTS["err_no_path"][lang] if len(dataset) == 0: @@ -52,9 +60,8 @@ class Runner: self.trainer_callback = LogCallback(self) return "" - def _finalize( - self, lang: str, finish_info: str - ) -> str: + def _finalize(self, lang: str, finish_info: str) -> str: + self.thread = None self.running = False torch_gc() if self.aborted: @@ -62,236 +69,171 @@ class Runner: else: return finish_info - def _parse_train_args( - self, - lang: str, - model_name: str, - checkpoints: List[str], - finetuning_type: str, - quantization_bit: str, - template: str, - system_prompt: str, - flash_attn: bool, - shift_attn: bool, - rope_scaling: str, - training_stage: str, - dataset_dir: str, - dataset: List[str], - cutoff_len: int, - learning_rate: str, - num_train_epochs: str, - max_samples: str, - compute_type: str, - batch_size: int, - gradient_accumulation_steps: int, - lr_scheduler_type: str, - max_grad_norm: str, - val_size: float, - logging_steps: int, - save_steps: int, - warmup_steps: int, - lora_rank: int, - lora_dropout: float, - lora_target: str, - resume_lora_training: bool, - dpo_beta: float, - reward_model: str, - output_dir: str - ) -> Tuple[str, str, List[str], str, Dict[str, Any]]: - if checkpoints: - checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]) + def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]: + get = lambda name: data[self.manager.get_elem(name)] + + if get("top.checkpoints"): + checkpoint_dir = ",".join([ + get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") + ]) else: checkpoint_dir = None - output_dir = get_save_dir(model_name, finetuning_type, output_dir) - - user_config = load_config() - cache_dir = user_config.get("cache_dir", None) + output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")) args = dict( - stage=TRAINING_STAGES[training_stage], - model_name_or_path=get_model_path(model_name), + stage=TRAINING_STAGES[get("train.training_stage")], + model_name_or_path=get("top.model_path"), do_train=True, overwrite_cache=False, - cache_dir=cache_dir, + cache_dir=get("top.config").get("cache_dir", None), checkpoint_dir=checkpoint_dir, - finetuning_type=finetuning_type, - quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None, - template=template, - system_prompt=system_prompt, - flash_attn=flash_attn, - shift_attn=shift_attn, - rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None, - dataset_dir=dataset_dir, - dataset=",".join(dataset), - cutoff_len=cutoff_len, - learning_rate=float(learning_rate), - num_train_epochs=float(num_train_epochs), - max_samples=int(max_samples), - per_device_train_batch_size=batch_size, - gradient_accumulation_steps=gradient_accumulation_steps, - lr_scheduler_type=lr_scheduler_type, - max_grad_norm=float(max_grad_norm), - logging_steps=logging_steps, - save_steps=save_steps, - warmup_steps=warmup_steps, - lora_rank=lora_rank, - lora_dropout=lora_dropout, - lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"), - resume_lora_training=resume_lora_training, + finetuning_type=get("top.finetuning_type"), + quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + template=get("top.template"), + system_prompt=get("top.system_prompt"), + flash_attn=get("top.flash_attn"), + shift_attn=get("top.shift_attn"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + dataset_dir=get("train.dataset_dir"), + dataset=",".join(get("train.dataset")), + cutoff_len=get("train.cutoff_len"), + learning_rate=float(get("train.learning_rate")), + num_train_epochs=float(get("train.num_train_epochs")), + max_samples=int(get("train.max_samples")), + per_device_train_batch_size=get("train.batch_size"), + gradient_accumulation_steps=get("train.gradient_accumulation_steps"), + lr_scheduler_type=get("train.lr_scheduler_type"), + max_grad_norm=float(get("train.max_grad_norm")), + logging_steps=get("train.logging_steps"), + save_steps=get("train.save_steps"), + warmup_steps=get("train.warmup_steps"), + lora_rank=get("train.lora_rank"), + lora_dropout=get("train.lora_dropout"), + lora_target=get("train.lora_target") or get_module(get("top.model_name")), + resume_lora_training=get("train.resume_lora_training"), output_dir=output_dir ) - args[compute_type] = True + args[get("train.compute_type")] = True - if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None: - args["resume_lora_training"] = False + if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]: + args["resume_lora_training"] = (args["quantization_bit"] is not None) if args["quantization_bit"] is not None: args["upcast_layernorm"] = True if args["stage"] == "ppo": - args["reward_model"] = reward_model - val_size = 0 + args["reward_model"] = get("train.reward_model") if args["stage"] == "dpo": - args["dpo_beta"] = dpo_beta + args["dpo_beta"] = get("train.dpo_beta") - if val_size > 1e-6: - args["val_size"] = val_size + if get("train.val_size") > 1e-6 and args["stage"] != "ppo": + args["val_size"] = get("train.val_size") args["evaluation_strategy"] = "steps" - args["eval_steps"] = save_steps + args["eval_steps"] = get("train.save_steps") args["load_best_model_at_end"] = True - return lang, model_name, dataset, output_dir, args + return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args - def _parse_eval_args( - self, - lang: str, - model_name: str, - checkpoints: List[str], - finetuning_type: str, - quantization_bit: str, - template: str, - system_prompt: str, - flash_attn: bool, - shift_attn: bool, - rope_scaling: str, - dataset_dir: str, - dataset: List[str], - cutoff_len: int, - max_samples: str, - batch_size: int, - predict: bool, - max_new_tokens: int, - top_p: float, - temperature: float - ) -> Tuple[str, str, List[str], str, Dict[str, Any]]: - if checkpoints: - checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]) - output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints)) + def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]: + get = lambda name: data[self.manager.get_elem(name)] + + if get("top.checkpoints"): + checkpoint_dir = ",".join([ + get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") + ]) + output_dir = get_save_dir( + get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints")) + ) else: checkpoint_dir = None - output_dir = get_save_dir(model_name, finetuning_type, "eval_base") - - user_config = load_config() - cache_dir = user_config.get("cache_dir", None) + output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base") args = dict( stage="sft", - model_name_or_path=get_model_path(model_name), + model_name_or_path=get("top.model_path"), do_eval=True, overwrite_cache=False, predict_with_generate=True, - cache_dir=cache_dir, + cache_dir=get("top.config").get("cache_dir", None), checkpoint_dir=checkpoint_dir, - finetuning_type=finetuning_type, - quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None, - template=template, - system_prompt=system_prompt, - flash_attn=flash_attn, - shift_attn=shift_attn, - rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None, - dataset_dir=dataset_dir, - dataset=",".join(dataset), - cutoff_len=cutoff_len, - max_samples=int(max_samples), - per_device_eval_batch_size=batch_size, - max_new_tokens=max_new_tokens, - top_p=top_p, - temperature=temperature, - output_dir=output_dir + finetuning_type=get("top.finetuning_type"), + quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + template=get("top.template"), + system_prompt=get("top.system_prompt"), + flash_attn=get("top.flash_attn"), + shift_attn=get("top.shift_attn"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + dataset_dir=get("eval.dataset_dir"), + dataset=",".join(get("eval.dataset")), + cutoff_len=get("eval.cutoff_len"), + max_samples=int(get("eval.max_samples")), + per_device_eval_batch_size=get("eval.batch_size"), + max_new_tokens=get("eval.max_new_tokens"), + top_p=get("eval.top_p"), + temperature=get("eval.temperature"), + output_dir=get("eval.output_dir") ) - if predict: + if get("eval.predict"): args.pop("do_eval", None) args["do_predict"] = True - return lang, model_name, dataset, output_dir, args + return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args - def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - lang, model_name, dataset, _, args = self._parse_train_args(*args) - error = self._initialize(lang, model_name, dataset) + def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, model_name, model_path, dataset, _, args = self._parse_train_args(data) + error = self._initialize(lang, model_name, model_path, dataset) if error: yield error, gr.update(visible=False) else: yield gen_cmd(args), gr.update(visible=False) - def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - lang, model_name, dataset, _, args = self._parse_eval_args(*args) - error = self._initialize(lang, model_name, dataset) + def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data) + error = self._initialize(lang, model_name, model_path, dataset) if error: yield error, gr.update(visible=False) else: yield gen_cmd(args), gr.update(visible=False) - def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - lang, model_name, dataset, output_dir, args = self._parse_train_args(*args) - error = self._initialize(lang, model_name, dataset) + def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + self.prepare(data, self._parse_train_args) + + def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + self.prepare(data, self._parse_eval_args) + + def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + parse_func = self._parse_train_args if is_training else self._parse_eval_args + lang, model_name, model_path, dataset, output_dir, args = parse_func(data) + error = self._initialize(lang, model_name, model_path, dataset) if error: yield error, gr.update(visible=False) - return + else: + self.running = True + run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) + thread = Thread(target=run_exp, kwargs=run_kwargs) + thread.start() + yield self.monitor(lang, output_dir, is_training) - self.running = True - run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) - thread = threading.Thread(target=run_exp, kwargs=run_kwargs) - thread.start() - - while thread.is_alive(): + def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + while self.thread.is_alive(): time.sleep(2) if self.aborted: yield ALERTS["info_aborting"][lang], gr.update(visible=False) else: yield self.logger_handler.log, update_process_bar(self.trainer_callback) - if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): - finish_info = ALERTS["info_finished"][lang] - else: - finish_info = ALERTS["err_failed"][lang] - - yield self._finalize(lang, finish_info), gr.update(visible=False) - - def run_eval(self, *args) -> Generator[str, None, None]: - lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args) - error = self._initialize(lang, model_name, dataset) - if error: - yield error, gr.update(visible=False) - return - - self.running = True - run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) - thread = threading.Thread(target=run_exp, kwargs=run_kwargs) - thread.start() - - while thread.is_alive(): - time.sleep(2) - if self.aborted: - yield ALERTS["info_aborting"][lang], gr.update(visible=False) + if is_training: + if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): + finish_info = ALERTS["info_finished"][lang] else: - yield self.logger_handler.log, update_process_bar(self.trainer_callback) - - if os.path.exists(os.path.join(output_dir, "all_results.json")): - finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) + finish_info = ALERTS["err_failed"][lang] else: - finish_info = ALERTS["err_failed"][lang] + if os.path.exists(os.path.join(output_dir, "all_results.json")): + finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) + else: + finish_info = ALERTS["err_failed"][lang] yield self._finalize(lang, finish_info), gr.update(visible=False) diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 52016378..ef324425 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -8,7 +8,7 @@ from datetime import datetime from llmtuner.extras.ploting import smooth from llmtuner.tuner import export_model -from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG +from llmtuner.webui.common import get_save_dir, DATA_CONFIG from llmtuner.webui.locales import ALERTS if TYPE_CHECKING: @@ -119,6 +119,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl def save_model( lang: str, model_name: str, + model_path: str, checkpoints: List[str], finetuning_type: str, template: str, @@ -129,8 +130,7 @@ def save_model( yield ALERTS["err_no_model"][lang] return - model_name_or_path = get_model_path(model_name) - if not model_name_or_path: + if not model_path: yield ALERTS["err_no_path"][lang] return @@ -138,17 +138,13 @@ def save_model( yield ALERTS["err_no_checkpoint"][lang] return - checkpoint_dir = ",".join( - [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints] - ) - if not save_dir: yield ALERTS["err_no_save_dir"][lang] return args = dict( - model_name_or_path=model_name_or_path, - checkpoint_dir=checkpoint_dir, + model_name_or_path=model_path, + checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]), finetuning_type=finetuning_type, template=template, output_dir=save_dir