refactor webui

This commit is contained in:
hiyouga 2023-10-15 03:06:21 +08:00
parent c874e764b8
commit 7ed1fa6fe9
14 changed files with 440 additions and 501 deletions

View File

@@ -1,69 +1,73 @@
from typing import Any, Dict, Generator, List, Optional, Tuple
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING:
from llmtuner.webui.manager import Manager
class WebChatModel(ChatModel):
def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
self.manager = manager
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
super().__init__(args)
super().__init__()
def load_model(
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str
) -> Generator[str, None, None]:
if self.model is not None:
@property
def loaded(self) -> bool:
return self.model is not None
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem(name)]
lang = get("top.lang")
if self.loaded:
yield ALERTS["err_exists"][lang]
return
if not model_name:
if not get("top.model_name"):
yield ALERTS["err_no_model"][lang]
return
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
if not get("top.model_path"):
yield ALERTS["err_no_path"][lang]
return
if checkpoints:
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
else:
checkpoint_dir = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_name_or_path,
model_name_or_path=get("top.model_path"),
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
system_prompt=get("top.system_prompt"),
flash_attn=get("top.flash_attn"),
shift_attn=get("top.shift_attn"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
)
super().__init__(args)
yield ALERTS["info_loaded"][lang]
def unload_model(self, lang: str) -> Generator[str, None, None]:
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem(name)]
lang = get("top.lang")
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None
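
The new load_model/unload_model signatures lean on a Gradio feature: when an event is wired with a set of components as inputs (rather than a list), the handler receives a single dict keyed by the component objects, which is what the `get = lambda name: data[self.manager.get_elem(name)]` helper indexes into. A minimal sketch of that mechanism, with made-up component names, not code from this commit:

import gradio as gr

def load_model(data):
    # `data` maps component objects to their current values
    return f"loading {data[model_name]} ({data[quantization_bit]}-bit)"

with gr.Blocks() as demo:
    model_name = gr.Textbox(label="model_name")
    quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
    info = gr.Textbox(interactive=False)
    load_btn = gr.Button("Load")
    load_btn.click(load_model, {model_name, quantization_bit}, [info])  # a set, not a list

demo.launch()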

View File

@@ -1,7 +1,7 @@
import os
import json
import gradio as gr
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
@@ -11,7 +11,7 @@ from transformers.utils import (
ADAPTER_SAFE_WEIGHTS_NAME
)
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
DEFAULT_CACHE_DIR = "cache"
@@ -27,6 +27,7 @@ CKPT_NAMES = [
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]
def get_save_dir(*args) -> os.PathLike:
@@ -37,7 +38,7 @@ def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Any]:
def load_config() -> CONFIG_CLASS:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
@@ -45,20 +46,24 @@ def load_config() -> Dict[str, Any]:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
def save_config(lang: str, model_name: str, model_path: str) -> None:
def save_config(
config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
config["lang"] = lang or config["lang"]
if model_name:
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
config["last_model"] = model_name
config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False)
json.dump(config, f, indent=2, ensure_ascii=False)
def get_model_path(model_name: str) -> str:
user_config = load_config()
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def get_model_path(config: Dict[str, Any], model_name: str) -> str:
return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
def get_module(model_name: str) -> str:
return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
def get_template(model_name: str) -> str:
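
A hypothetical round trip through the reworked helpers (the model name and path are made up): the caller now owns the config dict, typically held in a gr.State, instead of save_config re-reading it from disk on every call. Note that save_config still writes cache/user.config as a side effect.

from llmtuner.webui.common import load_config, save_config, get_model_path

config = load_config()  # {"lang": ..., "last_model": ..., "path_dict": {...}, ...}
save_config(config, lang="en", model_name="Custom", model_path="/models/my-llama")
# a user-supplied path takes precedence over the SUPPORTED_MODELS default:
assert get_model_path(config, "Custom") == "/models/my-llama"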

View File

@@ -4,13 +4,15 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.engine import Engine
def create_chat_box(
chat_model: "WebChatModel",
engine: "Engine",
visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
elem_dict = dict()
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
@@ -22,14 +24,20 @@ def create_chat_box(
with gr.Column(scale=1):
clear_btn = gr.Button()
max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
gen_kwargs = engine.chatter.generating_args
max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
elem_dict.update(dict(
system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn,
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
))
history = gr.State([])
submit_btn.click(
chat_model.predict,
engine.chatter.predict,
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
@@ -39,12 +47,4 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
system=system,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
)
return chat_box, chatbot, history, elem_dict

View File

@@ -7,19 +7,28 @@ from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
from llmtuner.webui.engine import Engine
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(
dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
))
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
@@ -27,17 +36,31 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
queue=False
)
elem_dict.update(dict(
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
))
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
))
with gr.Row():
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
input_elems.update({max_new_tokens, top_p, temperature})
elem_dict.update(dict(
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
))
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
@@ -49,53 +72,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
with gr.Box():
output_box = gr.Markdown()
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"],
dataset_dir,
dataset,
cutoff_len,
max_samples,
batch_size,
predict,
max_new_tokens,
top_p,
temperature
]
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box
))
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
start_btn.click(runner.run_eval, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
data_preview_btn=data_preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
cutoff_len=cutoff_len,
max_samples=max_samples,
batch_size=batch_size,
predict=predict,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_box=output_box
)
return elem_dict

View File

@@ -5,9 +5,12 @@ from llmtuner.webui.utils import save_model
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
elem_dict = dict()
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
@@ -18,20 +21,23 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component
export_btn.click(
save_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["template"],
engine.manager.get_elem("top.lang"),
engine.manager.get_elem("top.model_name"),
engine.manager.get_elem("top.model_path"),
engine.manager.get_elem("top.checkpoints"),
engine.manager.get_elem("top.finetuning_type"),
engine.manager.get_elem("top.template"),
max_shard_size,
save_dir
],
[info_box]
)
return dict(
elem_dict.update(dict(
save_dir=save_dir,
max_shard_size=max_shard_size,
export_btn=export_btn,
info_box=info_box
)
))
return elem_dict

View File

@@ -1,53 +1,42 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
chat_model = WebChatModel(lazy_init=True)
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
elem_dict.update(dict(
info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
))
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
load_btn.click(
chat_model.load_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"]
],
[info_box]
engine.chatter.load_model, input_elems, [info_box]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
unload_btn.click(
chat_model.unload_model, [top_elems["lang"]], [info_box]
engine.chatter.unload_model, input_elems, [info_box]
).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
)
return dict(
info_box=info_box,
load_btn=load_btn,
unload_btn=unload_btn,
**chat_elems
)
return elem_dict

View File

@@ -3,15 +3,17 @@ from typing import TYPE_CHECKING, Dict
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
def create_top() -> Dict[str, "Component"]:
def create_top(engine: "Engine") -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
config = gr.State(value=load_config())
with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
@@ -35,17 +37,21 @@ def create_top() -> Dict[str, "Component"]:
shift_attn = gr.Checkbox(value=False)
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
lang.change(save_config, [lang, model_name, model_path])
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang, model_name, model_path]
)
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
get_model_path, [config, model_name], [model_path]
).then(
get_template, [model_name], [template]
) # do not save config since the below line will save
model_path.change(save_config, [lang, model_name, model_path])
model_path.change(save_config, inputs=[config, lang, model_name, model_path])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -58,6 +64,7 @@ def create_top() -> Dict[str, "Component"]:
)
return dict(
config=config,
lang=lang,
model_name=model_name,
model_path=model_path,

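The lang.change(...).then(save_config, ...) wiring above relies on Gradio event chaining: each .then() step starts only after the previous handler finishes, so the UI is re-localized before the config is persisted. A minimal sketch of the mechanism (components and handlers are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"], value="en")
    status = gr.Textbox()
    lang.change(
        lambda l: f"relabel UI for {l}", [lang], [status]
    ).then(
        lambda l: print(f"persist lang={l}"), inputs=[lang]  # runs after the step above
    )

demo.launch()
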
View File

@@ -9,10 +9,13 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
from llmtuner.webui.engine import Engine
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
@@ -21,11 +24,17 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
))
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
@@ -33,6 +42,10 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
queue=False
)
elem_dict.update(dict(
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
))
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
learning_rate = gr.Textbox(value="5e-5")
@@ -40,6 +53,12 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
max_samples = gr.Textbox(value="100000")
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
elem_dict.update(dict(
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
max_samples=max_samples, compute_type=compute_type
))
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
@@ -49,12 +68,23 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
elem_dict.update(dict(
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
))
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
input_elems.update({logging_steps, save_steps, warmup_steps})
elem_dict.update(dict(
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps
))
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
@@ -62,6 +92,15 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
lora_target = gr.Textbox(scale=2)
resume_lora_training = gr.Checkbox(value=True, scale=1)
input_elems.update({lora_rank, lora_dropout, lora_target, resume_lora_training})
elem_dict.update(dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
))
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
@@ -70,11 +109,14 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
refresh_btn.click(
list_checkpoint,
[top_elems["model_name"], top_elems["finetuning_type"]],
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")],
[reward_model],
queue=False
)
input_elems.update({dpo_beta, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
@@ -94,90 +136,22 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"],
training_stage,
dataset_dir,
dataset,
cutoff_len,
learning_rate,
num_train_epochs,
max_samples,
compute_type,
batch_size,
gradient_accumulation_steps,
lr_scheduler_type,
max_grad_norm,
val_size,
logging_steps,
save_steps,
warmup_steps,
lora_rank,
lora_dropout,
lora_target,
resume_lora_training,
dpo_beta,
reward_model,
output_dir
]
input_elems.add(output_dir)
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer
))
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(runner.preview_train, input_components, output_components)
start_btn.click(runner.run_train, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
process_bar.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
gen_plot,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
loss_viewer,
queue=False
)
return dict(
training_stage=training_stage,
dataset_dir=dataset_dir,
dataset=dataset,
data_preview_btn=data_preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
cutoff_len=cutoff_len,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
compute_type=compute_type,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
val_size=val_size,
advanced_tab=advanced_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
rlhf_tab=rlhf_tab,
dpo_beta=dpo_beta,
reward_model=reward_model,
refresh_btn=refresh_btn,
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
output_box=output_box,
loss_viewer=loss_viewer
)
return elem_dict

View File

@@ -0,0 +1,46 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, List, Optional, Tuple
from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import get_time
class Engine:
def __init__(self, init_chat: Optional[bool] = False) -> None:
self.manager: "Manager" = Manager()
self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat))
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
lang = config.get("lang", None) or "en"
resume_dict = {
"top.config": {"value": config},
"top.lang": {"value": lang},
"train.dataset": {"choices": list_dataset()["choices"]},
"eval.dataset": {"choices": list_dataset()["choices"]},
"infer.chat_box": {"visible": self.chatter.loaded}
}
if config.get("last_model", None):
resume_dict["top.model_name"] = {"value": config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
if self.runner.alive:
pass # TODO: restore training
else:
resume_dict = {"train.output_dir": {"value": get_time()}} # TODO: xxx
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return {
component: gr.update(**LOCALES[name][lang])
for elems in self.manager.all_elems.values() for name, component in elems.items()
}
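
change_lang returns a dict that maps components to gr.update payloads; when the outputs of an event contain those components, Gradio applies each update to its key. A small sketch of that return style (the components and labels are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    name = gr.Textbox(label="Model name")
    path = gr.Textbox(label="Model path")
    to_zh = gr.Button("zh")
    # returning {component: gr.update(...)} updates any subset of the outputs
    to_zh.click(
        lambda: {name: gr.update(label="模型名称"), path: gr.update(label="模型路径")},
        outputs=[name, path]
    )

demo.launch()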

View File

@@ -9,65 +9,54 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.common import load_config, save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
from llmtuner.webui.engine import Engine
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def create_ui() -> gr.Blocks:
runner = Runner()
engine = Engine(init_chat=False)
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
engine.manager.all_elems["top"] = create_top(engine)
with gr.Tab("Train"):
train_elems = create_train_tab(top_elems, runner)
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):
infer_elems = create_infer_tab(top_elems)
engine.manager.all_elems["infer"] = create_infer_tab(engine)
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
engine.manager.all_elems["export"] = create_export_tab(engine)
elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)
demo.load(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
)
top_elems["lang"].change(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
queue=False
)
demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems())
return demo
def create_web_demo() -> gr.Blocks:
chat_model = WebChatModel(lazy_init=False)
engine = Engine(init_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
lang = gr.Dropdown(choices=["en", "zh"])
config = gr.State(value=load_config())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang]
)
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
engine.manager.all_elems["top"] = dict(lang=lang)
manager = Manager([{"lang": lang}, chat_elems])
_, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
demo.load(engine.resume, [config], engine.manager.list_elems())
return demo
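
A hypothetical launcher for these builders (the repository's real entry script may differ); queue() is needed because several handlers, such as Engine.resume and the runner's monitor loop, are generators that stream updates:

from llmtuner.webui.interface import create_ui

if __name__ == "__main__":
    demo = create_ui()
    demo.queue()
    demo.launch(server_name="0.0.0.0")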

View File

@@ -1,4 +1,8 @@
LOCALES = {
"config": {
"en": {},
"zh": {}
},
"lang": {
"en": {
"label": "Lang"
@@ -443,6 +447,10 @@ LOCALES = {
"label": "保存预测结果"
}
},
"chat_box": {
"en": {},
"zh": {}
},
"load_btn": {
"en": {
"value": "Load model"

View File

@@ -1,46 +1,36 @@
import gradio as gr
from gradio.components import Component
from typing import Any, Dict, List
from typing import TYPE_CHECKING, Dict, List
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time
if TYPE_CHECKING:
from gradio.components import Component
class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
self.elem_list = elem_list
def __init__(self) -> None:
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
def gen_refresh(self, lang: str) -> Dict[str, Any]:
refresh_dict = {
"dataset": {"choices": list_dataset()["choices"]},
"output_dir": {"value": get_time()}
def get_elem(self, name: str) -> "Component":
r"""
Example: top.lang, train.dataset
"""
tab_name, elem_name = name.split(".")
return self.all_elems[tab_name][elem_name]
def get_base_elems(self):
return {
self.all_elems["top"]["config"],
self.all_elems["top"]["lang"],
self.all_elems["top"]["model_name"],
self.all_elems["top"]["model_path"],
self.all_elems["top"]["checkpoints"],
self.all_elems["top"]["finetuning_type"],
self.all_elems["top"]["quantization_bit"],
self.all_elems["top"]["template"],
self.all_elems["top"]["system_prompt"],
self.all_elems["top"]["flash_attn"],
self.all_elems["top"]["shift_attn"],
self.all_elems["top"]["rope_scaling"]
}
user_config = load_config()
if not lang:
if user_config.get("lang", None):
lang = user_config["lang"]
else:
lang = "en"
refresh_dict["lang"] = {"value": lang}
if user_config.get("last_model", None):
refresh_dict["model_name"] = {"value": user_config["last_model"]}
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
update_dict = {}
refresh_dict = self.gen_refresh(lang)
for elems in self.elem_list:
for name, component in elems.items():
update_dict[component] = gr.update(
**LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
)
return update_dict
def list_elems(self) -> List["Component"]:
return [elem for elems in self.all_elems.values() for elem in elems.values()]
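
A hypothetical snippet showing how the registry is consumed, assuming the Manager above is importable: each tab registers its components under a namespace, and any module can then fetch them by a dotted "tab.name" key.

import gradio as gr
from llmtuner.webui.manager import Manager

manager = Manager()
with gr.Blocks():
    manager.all_elems["top"] = {"lang": gr.Dropdown(choices=["en", "zh"])}

assert manager.get_elem("top.lang") is manager.all_elems["top"]["lang"]
print(manager.list_elems())  # flat component list, usable as a global `outputs`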

View File

@@ -1,26 +1,32 @@
import os
import time
import logging
import threading
import gradio as gr
from typing import Any, Dict, Generator, List, Tuple
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir, load_config
from llmtuner.webui.common import get_module, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING:
from llmtuner.webui.manager import Manager
class Runner:
def __init__(self):
def __init__(self, manager: "Manager") -> None:
self.manager = manager
self.thread: "Thread" = None
self.aborted = False
self.running = False
self.logger_handler = LoggerHandler()
@@ -28,20 +34,22 @@ class Runner:
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def set_abort(self):
@property
def alive(self) -> bool:
return self.thread is not None
def set_abort(self) -> None:
self.aborted = True
self.running = False
def _initialize(
self, lang: str, model_name: str, dataset: List[str]
) -> str:
def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
if self.running:
return ALERTS["err_conflict"][lang]
if not model_name:
return ALERTS["err_no_model"][lang]
if not get_model_path(model_name):
if not model_path:
return ALERTS["err_no_path"][lang]
if len(dataset) == 0:
@@ -52,9 +60,8 @@ class Runner:
self.trainer_callback = LogCallback(self)
return ""
def _finalize(
self, lang: str, finish_info: str
) -> str:
def _finalize(self, lang: str, finish_info: str) -> str:
self.thread = None
self.running = False
torch_gc()
if self.aborted:
@@ -62,236 +69,171 @@
else:
return finish_info
def _parse_train_args(
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
training_stage: str,
dataset_dir: str,
dataset: List[str],
cutoff_len: int,
learning_rate: str,
num_train_epochs: str,
max_samples: str,
compute_type: str,
batch_size: int,
gradient_accumulation_steps: int,
lr_scheduler_type: str,
max_grad_norm: str,
val_size: float,
logging_steps: int,
save_steps: int,
warmup_steps: int,
lora_rank: int,
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
dpo_beta: float,
reward_model: str,
output_dir: str
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
else:
checkpoint_dir = None
output_dir = get_save_dir(model_name, finetuning_type, output_dir)
user_config = load_config()
cache_dir = user_config.get("cache_dir", None)
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
args = dict(
stage=TRAINING_STAGES[training_stage],
model_name_or_path=get_model_path(model_name),
stage=TRAINING_STAGES[get("train.training_stage")],
model_name_or_path=get("top.model_path"),
do_train=True,
overwrite_cache=False,
cache_dir=cache_dir,
cache_dir=get("top.config").get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
cutoff_len=cutoff_len,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
max_samples=int(max_samples),
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=float(max_grad_norm),
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=resume_lora_training,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
system_prompt=get("top.system_prompt"),
flash_attn=get("top.flash_attn"),
shift_attn=get("top.shift_attn"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
learning_rate=float(get("train.learning_rate")),
num_train_epochs=float(get("train.num_train_epochs")),
max_samples=int(get("train.max_samples")),
per_device_train_batch_size=get("train.batch_size"),
gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
lr_scheduler_type=get("train.lr_scheduler_type"),
max_grad_norm=float(get("train.max_grad_norm")),
logging_steps=get("train.logging_steps"),
save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"),
lora_rank=get("train.lora_rank"),
lora_dropout=get("train.lora_dropout"),
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
resume_lora_training=get("train.resume_lora_training"),
output_dir=output_dir
)
args[compute_type] = True
args[get("train.compute_type")] = True
if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
args["resume_lora_training"] = False
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
args["resume_lora_training"] = (args["quantization_bit"] is not None)
if args["quantization_bit"] is not None:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = reward_model
val_size = 0
args["reward_model"] = get("train.reward_model")
if args["stage"] == "dpo":
args["dpo_beta"] = dpo_beta
args["dpo_beta"] = get("train.dpo_beta")
if val_size > 1e-6:
args["val_size"] = val_size
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = save_steps
args["eval_steps"] = get("train.save_steps")
args["load_best_model_at_end"] = True
return lang, model_name, dataset, output_dir, args
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args
def _parse_eval_args(
self,
lang: str,
model_name: str,
checkpoints: List[str],
finetuning_type: str,
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
dataset_dir: str,
dataset: List[str],
cutoff_len: int,
max_samples: str,
batch_size: int,
predict: bool,
max_new_tokens: int,
top_p: float,
temperature: float
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
else:
checkpoint_dir = None
output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
user_config = load_config()
cache_dir = user_config.get("cache_dir", None)
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
args = dict(
stage="sft",
model_name_or_path=get_model_path(model_name),
model_name_or_path=get("top.model_path"),
do_eval=True,
overwrite_cache=False,
predict_with_generate=True,
cache_dir=cache_dir,
cache_dir=get("top.config").get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
cutoff_len=cutoff_len,
max_samples=int(max_samples),
per_device_eval_batch_size=batch_size,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
output_dir=output_dir
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
system_prompt=get("top.system_prompt"),
flash_attn=get("top.flash_attn"),
shift_attn=get("top.shift_attn"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
dataset_dir=get("eval.dataset_dir"),
dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),
max_samples=int(get("eval.max_samples")),
per_device_eval_batch_size=get("eval.batch_size"),
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=output_dir
)
if predict:
if get("eval.predict"):
args.pop("do_eval", None)
args["do_predict"] = True
return lang, model_name, dataset, output_dir, args
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, _, args = self._parse_train_args(*args)
error = self._initialize(lang, model_name, dataset)
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, model_path, dataset, _, args = self._parse_train_args(data)
error = self._initialize(lang, model_name, model_path, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, _, args = self._parse_eval_args(*args)
error = self._initialize(lang, model_name, dataset)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data)
error = self._initialize(lang, model_name, model_path, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
error = self._initialize(lang, model_name, dataset)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self.prepare(data, is_training=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
yield from self.prepare(data, is_training=False)
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
parse_func = self._parse_train_args if is_training else self._parse_eval_args
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
error = self._initialize(lang, model_name, model_path, dataset)
if error:
yield error, gr.update(visible=False)
return
else:
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
yield from self.monitor(lang, output_dir, is_training)
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
while self.thread.is_alive():
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.update(visible=False)
def run_eval(self, *args) -> Generator[str, None, None]:
lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
return
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
if is_training:
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
finish_info = ALERTS["err_failed"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.update(visible=False)
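
The run/monitor split reduces to: start the experiment in a background thread, then poll shared state from a generator that Gradio streams into the output box. A stripped-down, self-contained sketch of the pattern (names are illustrative, not the commit's API):

import threading
import time

def run_job(state):
    for step in range(3):  # stands in for run_exp(...)
        time.sleep(1)
        state["log"] = f"step {step} done"
    state["finished"] = True

def monitor(state):
    # poll shared state and stream snapshots, like Runner.monitor
    while not state.get("finished"):
        time.sleep(2)
        yield state.get("log", "starting...")
    yield "finished"

state = {}
threading.Thread(target=run_job, args=(state,), daemon=True).start()
for line in monitor(state):
    print(line)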

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING:
@@ -119,6 +119,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
def save_model(
lang: str,
model_name: str,
model_path: str,
checkpoints: List[str],
finetuning_type: str,
template: str,
@@ -129,8 +130,7 @@ def save_model(
yield ALERTS["err_no_model"][lang]
return
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
if not model_path:
yield ALERTS["err_no_path"][lang]
return
@@ -138,17 +138,13 @@
yield ALERTS["err_no_checkpoint"][lang]
return
checkpoint_dir = ",".join(
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
if not save_dir:
yield ALERTS["err_no_save_dir"][lang]
return
args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
model_name_or_path=model_path,
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
finetuning_type=finetuning_type,
template=template,
output_dir=save_dir