refactor webui
This commit is contained in:
parent c874e764b8
commit 7ed1fa6fe9
@@ -1,69 +1,73 @@
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.webui.common import get_model_path, get_save_dir
+from llmtuner.webui.common import get_save_dir
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.webui.manager import Manager
+

 class WebChatModel(ChatModel):

-    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
+    def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
+        self.manager = manager
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
         if not lazy_init:
-            super().__init__(args)
+            super().__init__()

-    def load_model(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        system_prompt: str,
-        flash_attn: bool,
-        shift_attn: bool,
-        rope_scaling: str
-    ) -> Generator[str, None, None]:
-        if self.model is not None:
+    @property
+    def loaded(self) -> bool:
+        return self.model is not None
+
+    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+        get = lambda name: data[self.manager.get_elem(name)]
+        lang = get("top.lang")
+
+        if self.loaded:
             yield ALERTS["err_exists"][lang]
             return

-        if not model_name:
+        if not get("top.model_name"):
             yield ALERTS["err_no_model"][lang]
             return

-        model_name_or_path = get_model_path(model_name)
-        if not model_name_or_path:
+        if not get("top.model_path"):
             yield ALERTS["err_no_path"][lang]
             return

-        if checkpoints:
-            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
+        if get("top.checkpoints"):
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None

         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=model_name_or_path,
+            model_name_or_path=get("top.model_path"),
             checkpoint_dir=checkpoint_dir,
-            finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
-            template=template,
-            system_prompt=system_prompt,
-            flash_attn=flash_attn,
-            shift_attn=shift_attn,
-            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            system_prompt=get("top.system_prompt"),
+            flash_attn=get("top.flash_attn"),
+            shift_attn=get("top.shift_attn"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
         )
         super().__init__(args)

         yield ALERTS["info_loaded"][lang]

-    def unload_model(self, lang: str) -> Generator[str, None, None]:
+    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+        get = lambda name: data[self.manager.get_elem(name)]
+        lang = get("top.lang")
+
         yield ALERTS["info_unloading"][lang]
         self.model = None
         self.tokenizer = None
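Note: the refactored load_model/unload_model rely on Gradio's dict-style event data. When an event handler is wired with a *set* of input components instead of a list, Gradio passes the callback a single {component: value} mapping, which `get = lambda name: data[self.manager.get_elem(name)]` then resolves by dotted element name. A minimal, self-contained sketch of just that mechanism (component names are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    model_name = gr.Textbox(label="model_name")
    model_path = gr.Textbox(label="model_path")
    status = gr.Textbox(label="status")

    def load(data):
        # with a set of inputs, `data` is a {component: value} dict
        return "loading {} from {}".format(data[model_name], data[model_path])

    gr.Button("Load").click(load, {model_name, model_path}, status)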
@@ -1,7 +1,7 @@
 import os
 import json
 import gradio as gr
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 from transformers.utils import (
     WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -11,7 +11,7 @@ from transformers.utils import (
     ADAPTER_SAFE_WEIGHTS_NAME
 )

-from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
+from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES


 DEFAULT_CACHE_DIR = "cache"
@@ -27,6 +27,7 @@ CKPT_NAMES = [
     ADAPTER_WEIGHTS_NAME,
     ADAPTER_SAFE_WEIGHTS_NAME
 ]
+CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]


 def get_save_dir(*args) -> os.PathLike:
@@ -37,7 +38,7 @@ def get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> Dict[str, Any]:
+def load_config() -> CONFIG_CLASS:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
@@ -45,20 +46,24 @@ def load_config() -> Dict[str, Any]:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}


-def save_config(lang: str, model_name: str, model_path: str) -> None:
+def save_config(
+    config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
-    user_config = load_config()
-    user_config["lang"] = lang or user_config["lang"]
+    config["lang"] = lang or config["lang"]
     if model_name:
-        user_config["last_model"] = model_name
-        user_config["path_dict"][model_name] = model_path
+        config["last_model"] = model_name
+        config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
-        json.dump(user_config, f, indent=2, ensure_ascii=False)
+        json.dump(config, f, indent=2, ensure_ascii=False)


-def get_model_path(model_name: str) -> str:
-    user_config = load_config()
-    return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
+def get_model_path(config: Dict[str, Any], model_name: str) -> str:
+    return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
+
+
+def get_module(model_name: str) -> str:
+    return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")


 def get_template(model_name: str) -> str:
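A hypothetical round trip through the reworked helpers above; the config dict now travels through callbacks explicitly instead of being re-read from disk on every call (the model name and path here are made-up examples):

config = load_config()  # falls back to {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
save_config(config, lang="en", model_name="LLaMA-7B", model_path="/models/llama-7b")
assert get_model_path(config, "LLaMA-7B") == "/models/llama-7b"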
@@ -4,13 +4,15 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple
 if TYPE_CHECKING:
     from gradio.blocks import Block
     from gradio.components import Component
-    from llmtuner.webui.chat import WebChatModel
+    from llmtuner.webui.engine import Engine


 def create_chat_box(
-    chat_model: "WebChatModel",
+    engine: "Engine",
     visible: Optional[bool] = False
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
+    elem_dict = dict()
+
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()

@@ -22,14 +24,20 @@ def create_chat_box(

             with gr.Column(scale=1):
                 clear_btn = gr.Button()
-                max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
-                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
-                temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
+                gen_kwargs = engine.chatter.generating_args
+                max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
+                top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
+                temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+
+    elem_dict.update(dict(
+        system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn,
+        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+    ))

     history = gr.State([])

     submit_btn.click(
-        chat_model.predict,
+        engine.chatter.predict,
         [chatbot, query, history, system, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
@@ -39,12 +47,4 @@ def create_chat_box(

     clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

-    return chat_box, chatbot, history, dict(
-        system=system,
-        query=query,
-        submit_btn=submit_btn,
-        clear_btn=clear_btn,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        temperature=temperature
-    )
+    return chat_box, chatbot, history, elem_dict
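For context, the `engine.chatter.predict` handler wired above is a generator: each yielded (chatbot, history) pair streams into the two output components as tokens arrive. A stand-in sketch matching that input/output signature (the token source is fake):

def predict(chatbot, query, history, system, max_new_tokens, top_p, temperature):
    response = ""
    for token in ["Hello", ",", " world"]:  # stands in for streamed model tokens
        response += token
        yield chatbot + [[query, response]], history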
@@ -7,19 +7,28 @@ from llmtuner.webui.utils import can_preview, get_preview

 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.runner import Runner
+    from llmtuner.webui.engine import Engine


-def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
+    input_elems = engine.manager.get_base_elems()
+    elem_dict = dict()
+
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)

-    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
-
     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
     dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])

+    input_elems.update({dataset_dir, dataset})
+    elem_dict.update(dict(
+        dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
+    ))
+
+    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
+
     data_preview_btn.click(
         get_preview,
         [dataset_dir, dataset],
@@ -27,17 +36,31 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         queue=False
     )

+    elem_dict.update(dict(
+        preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
+    ))
+
     with gr.Row():
         cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
         max_samples = gr.Textbox(value="100000")
         batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
         predict = gr.Checkbox(value=True)

+    input_elems.update({cutoff_len, max_samples, batch_size, predict})
+    elem_dict.update(dict(
+        cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
+    ))
+
     with gr.Row():
         max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
         top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
         temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)

+    input_elems.update({max_new_tokens, top_p, temperature})
+    elem_dict.update(dict(
+        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+    ))
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
@@ -49,53 +72,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
     with gr.Box():
         output_box = gr.Markdown()

-    input_components = [
-        top_elems["lang"],
-        top_elems["model_name"],
-        top_elems["checkpoints"],
-        top_elems["finetuning_type"],
-        top_elems["quantization_bit"],
-        top_elems["template"],
-        top_elems["system_prompt"],
-        top_elems["flash_attn"],
-        top_elems["shift_attn"],
-        top_elems["rope_scaling"],
-        dataset_dir,
-        dataset,
-        cutoff_len,
-        max_samples,
-        batch_size,
-        predict,
-        max_new_tokens,
-        top_p,
-        temperature
-    ]
-
-    output_components = [
-        output_box,
-        process_bar
-    ]
+    output_elems = [output_box, process_bar]
+    elem_dict.update(dict(
+        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box
+    ))

-    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
-    start_btn.click(runner.run_eval, input_components, output_components)
-    stop_btn.click(runner.set_abort, queue=False)
+    cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
+    start_btn.click(engine.runner.run_eval, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort, queue=False)

-    return dict(
-        dataset_dir=dataset_dir,
-        dataset=dataset,
-        data_preview_btn=data_preview_btn,
-        preview_count=preview_count,
-        preview_samples=preview_samples,
-        close_btn=close_btn,
-        cutoff_len=cutoff_len,
-        max_samples=max_samples,
-        batch_size=batch_size,
-        predict=predict,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        temperature=temperature,
-        cmd_preview_btn=cmd_preview_btn,
-        start_btn=start_btn,
-        stop_btn=stop_btn,
-        output_box=output_box
-    )
+    return elem_dict
@@ -5,9 +5,12 @@ from llmtuner.webui.utils import save_model

 if TYPE_CHECKING:
     from gradio.components import Component
+    from llmtuner.webui.engine import Engine


-def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
+def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
+    elem_dict = dict()
+
     with gr.Row():
         save_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
@@ -18,20 +21,23 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component
     export_btn.click(
         save_model,
         [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["template"],
+            engine.manager.get_elem("top.lang"),
+            engine.manager.get_elem("top.model_name"),
+            engine.manager.get_elem("top.model_path"),
+            engine.manager.get_elem("top.checkpoints"),
+            engine.manager.get_elem("top.finetuning_type"),
+            engine.manager.get_elem("top.template"),
             max_shard_size,
             save_dir
         ],
         [info_box]
     )

-    return dict(
+    elem_dict.update(dict(
         save_dir=save_dir,
         max_shard_size=max_shard_size,
         export_btn=export_btn,
         info_box=info_box
-    )
+    ))
+
+    return elem_dict
@@ -1,53 +1,42 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.webui.chat import WebChatModel
 from llmtuner.webui.components.chatbot import create_chat_box

 if TYPE_CHECKING:
     from gradio.components import Component
+    from llmtuner.webui.engine import Engine


-def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
+def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
+    input_elems = engine.manager.get_base_elems()
+    elem_dict = dict()
+
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()

     info_box = gr.Textbox(show_label=False, interactive=False)

-    chat_model = WebChatModel(lazy_init=True)
-    chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
+    elem_dict.update(dict(
+        info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
+    ))
+
+    chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
+    elem_dict.update(dict(chat_box=chat_box, **chat_elems))

     load_btn.click(
-        chat_model.load_model,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["system_prompt"],
-            top_elems["flash_attn"],
-            top_elems["shift_attn"],
-            top_elems["rope_scaling"]
-        ],
-        [info_box]
+        engine.chatter.load_model, input_elems, [info_box]
     ).then(
-        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
+        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )

     unload_btn.click(
-        chat_model.unload_model, [top_elems["lang"]], [info_box]
+        engine.chatter.unload_model, input_elems, [info_box]
     ).then(
         lambda: ([], []), outputs=[chatbot, history]
     ).then(
-        lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
+        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )

-    return dict(
-        info_box=info_box,
-        load_btn=load_btn,
-        unload_btn=unload_btn,
-        **chat_elems
-    )
+    return elem_dict
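The load/unload wiring above leans on Gradio's .then() chaining: each chained callback runs only after the previous one finishes, which is why the chat box becomes visible only once loading has completed. A self-contained sketch of the same pattern:

import gradio as gr

with gr.Blocks() as demo:
    status = gr.Textbox()
    with gr.Column(visible=False) as panel:
        gr.Markdown("ready")
    load = gr.Button("Load")
    load.click(lambda: "loaded", outputs=[status]).then(
        lambda: gr.update(visible=True), outputs=[panel]
    )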
@@ -3,15 +3,17 @@ from typing import TYPE_CHECKING, Dict

 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
+from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
 from llmtuner.webui.utils import can_quantize

 if TYPE_CHECKING:
     from gradio.components import Component
+    from llmtuner.webui.engine import Engine


-def create_top() -> Dict[str, "Component"]:
+def create_top(engine: "Engine") -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+    config = gr.State(value=load_config())

     with gr.Row():
         lang = gr.Dropdown(choices=["en", "zh"], scale=1)
@@ -35,17 +37,21 @@ def create_top() -> Dict[str, "Component"]:
         shift_attn = gr.Checkbox(value=False)
         rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")

-    lang.change(save_config, [lang, model_name, model_path])
+    lang.change(
+        engine.change_lang, [lang], engine.manager.list_elems(), queue=False
+    ).then(
+        save_config, inputs=[config, lang, model_name, model_path]
+    )

     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
-        get_model_path, [model_name], [model_path]
+        get_model_path, [config, model_name], [model_path]
     ).then(
         get_template, [model_name], [template]
     )  # do not save config since the below line will save

-    model_path.change(save_config, [lang, model_name, model_path])
+    model_path.change(save_config, inputs=[config, lang, model_name, model_path])

     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -58,6 +64,7 @@ def create_top() -> Dict[str, "Component"]:
     )

     return dict(
+        config=config,
         lang=lang,
         model_name=model_name,
         model_path=model_path,
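`config = gr.State(value=load_config())` holds the user config server-side, per session; it is passed to callbacks like any other input value, and in-place mutations persist across events within that session. A minimal sketch of the mechanism:

import gradio as gr

with gr.Blocks() as demo:
    state = gr.State(value={"path_dict": {}})
    box = gr.Textbox()

    def remember(cfg, text):
        cfg["path_dict"]["last"] = text  # mutates this session's dict in place
        return "stored " + text

    box.submit(remember, [state, box], [box])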
@@ -9,10 +9,13 @@ from llmtuner.webui.utils import can_preview, get_preview, gen_plot

 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.runner import Runner
+    from llmtuner.webui.engine import Engine


-def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+    input_elems = engine.manager.get_base_elems()
+    elem_dict = dict()
+
     with gr.Row():
         training_stage = gr.Dropdown(
             choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
@@ -21,11 +24,17 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         dataset = gr.Dropdown(multiselect=True, scale=4)
         data_preview_btn = gr.Button(interactive=False, scale=1)

-    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
-
     training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
     dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
     dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])

+    input_elems.update({training_stage, dataset_dir, dataset})
+    elem_dict.update(dict(
+        training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
+    ))
+
+    preview_box, preview_count, preview_samples, close_btn = create_preview_box()
+
     data_preview_btn.click(
         get_preview,
         [dataset_dir, dataset],
@@ -33,6 +42,10 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         queue=False
     )

+    elem_dict.update(dict(
+        preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
+    ))
+
     with gr.Row():
         cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
         learning_rate = gr.Textbox(value="5e-5")
@@ -40,6 +53,12 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         max_samples = gr.Textbox(value="100000")
         compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")

+    input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
+    elem_dict.update(dict(
+        cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
+        max_samples=max_samples, compute_type=compute_type
+    ))
+
     with gr.Row():
         batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
@@ -49,12 +68,23 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

+    input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
+    elem_dict.update(dict(
+        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
+        lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
+    ))
+
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)

+    input_elems.update({logging_steps, save_steps, warmup_steps})
+    elem_dict.update(dict(
+        advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps
+    ))
+
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
@@ -62,6 +92,15 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
             lora_target = gr.Textbox(scale=2)
             resume_lora_training = gr.Checkbox(value=True, scale=1)

+    input_elems.update({lora_rank, lora_dropout, lora_target, resume_lora_training})
+    elem_dict.update(dict(
+        lora_tab=lora_tab,
+        lora_rank=lora_rank,
+        lora_dropout=lora_dropout,
+        lora_target=lora_target,
+        resume_lora_training=resume_lora_training
+    ))
+
     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
@@ -70,11 +109,14 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic

     refresh_btn.click(
         list_checkpoint,
-        [top_elems["model_name"], top_elems["finetuning_type"]],
+        [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")],
         [reward_model],
         queue=False
     )

+    input_elems.update({dpo_beta, reward_model})
+    elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
@@ -94,90 +136,22 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()

-    input_components = [
-        top_elems["lang"],
-        top_elems["model_name"],
-        top_elems["checkpoints"],
-        top_elems["finetuning_type"],
-        top_elems["quantization_bit"],
-        top_elems["template"],
-        top_elems["system_prompt"],
-        top_elems["flash_attn"],
-        top_elems["shift_attn"],
-        top_elems["rope_scaling"],
-        training_stage,
-        dataset_dir,
-        dataset,
-        cutoff_len,
-        learning_rate,
-        num_train_epochs,
-        max_samples,
-        compute_type,
-        batch_size,
-        gradient_accumulation_steps,
-        lr_scheduler_type,
-        max_grad_norm,
-        val_size,
-        logging_steps,
-        save_steps,
-        warmup_steps,
-        lora_rank,
-        lora_dropout,
-        lora_target,
-        resume_lora_training,
-        dpo_beta,
-        reward_model,
-        output_dir
-    ]
-
-    output_components = [
-        output_box,
-        process_bar
-    ]
+    input_elems.add(output_dir)
+    output_elems = [output_box, process_bar]
+    elem_dict.update(dict(
+        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
+        output_dir=output_dir, output_box=output_box, loss_viewer=loss_viewer
+    ))

-    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
-    start_btn.click(runner.run_train, input_components, output_components)
-    stop_btn.click(runner.set_abort, queue=False)
+    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
+    start_btn.click(engine.runner.run_train, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort, queue=False)

     process_bar.change(
-        gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
+        gen_plot,
+        [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
+        loss_viewer,
+        queue=False
     )

-    return dict(
-        training_stage=training_stage,
-        dataset_dir=dataset_dir,
-        dataset=dataset,
-        data_preview_btn=data_preview_btn,
-        preview_count=preview_count,
-        preview_samples=preview_samples,
-        close_btn=close_btn,
-        cutoff_len=cutoff_len,
-        learning_rate=learning_rate,
-        num_train_epochs=num_train_epochs,
-        max_samples=max_samples,
-        compute_type=compute_type,
-        batch_size=batch_size,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        lr_scheduler_type=lr_scheduler_type,
-        max_grad_norm=max_grad_norm,
-        val_size=val_size,
-        advanced_tab=advanced_tab,
-        logging_steps=logging_steps,
-        save_steps=save_steps,
-        warmup_steps=warmup_steps,
-        lora_tab=lora_tab,
-        lora_rank=lora_rank,
-        lora_dropout=lora_dropout,
-        lora_target=lora_target,
-        resume_lora_training=resume_lora_training,
-        rlhf_tab=rlhf_tab,
-        dpo_beta=dpo_beta,
-        reward_model=reward_model,
-        refresh_btn=refresh_btn,
-        cmd_preview_btn=cmd_preview_btn,
-        start_btn=start_btn,
-        stop_btn=stop_btn,
-        output_dir=output_dir,
-        output_box=output_box,
-        loss_viewer=loss_viewer
-    )
+    return elem_dict
@@ -0,0 +1,46 @@
+import gradio as gr
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+from llmtuner.webui.chatter import WebChatModel
+from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
+from llmtuner.webui.locales import LOCALES
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.runner import Runner
+from llmtuner.webui.utils import get_time
+
+
+class Engine:
+
+    def __init__(self, init_chat: Optional[bool] = False) -> None:
+        self.manager: "Manager" = Manager()
+        self.runner: "Runner" = Runner(self.manager)
+        self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat))
+
+    def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
+        lang = config.get("lang", None) or "en"
+
+        resume_dict = {
+            "top.config": {"value": config},
+            "top.lang": {"value": lang},
+            "train.dataset": {"choices": list_dataset()["choices"]},
+            "eval.dataset": {"choices": list_dataset()["choices"]},
+            "infer.chat_box": {"visible": self.chatter.loaded}
+        }
+
+        if config.get("last_model", None):
+            resume_dict["top.model_name"] = {"value": config["last_model"]}
+            resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
+
+        yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
+
+        if self.runner.alive:
+            pass  # TODO: restore training
+        else:
+            resume_dict = {"train.output_dir": {"value": get_time()}}  # TODO: xxx
+
+    def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
+        return {
+            component: gr.update(**LOCALES[name][lang])
+            for elems in self.manager.all_elems.values() for name, component in elems.items()
+        }
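Engine.resume exploits a Gradio convention: a callback whose outputs cover many components may yield a {component: gr.update(...)} dict, updating only the components it names. Assumed usage, once the tabs have been registered as in create_ui below:

engine = Engine(init_chat=False)
# first yield: a {Component: update} mapping restoring language, dataset choices, etc.
updates = next(engine.resume(load_config()))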
@@ -9,65 +9,54 @@ from llmtuner.webui.components import (
     create_export_tab,
     create_chat_box
 )
-from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.common import load_config, save_config
 from llmtuner.webui.css import CSS
-from llmtuner.webui.manager import Manager
-from llmtuner.webui.runner import Runner
+from llmtuner.webui.engine import Engine


 require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")


 def create_ui() -> gr.Blocks:
-    runner = Runner()
+    engine = Engine(init_chat=False)

     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
-        top_elems = create_top()
+        engine.manager.all_elems["top"] = create_top(engine)

         with gr.Tab("Train"):
-            train_elems = create_train_tab(top_elems, runner)
+            engine.manager.all_elems["train"] = create_train_tab(engine)

         with gr.Tab("Evaluate"):
-            eval_elems = create_eval_tab(top_elems, runner)
+            engine.manager.all_elems["eval"] = create_eval_tab(engine)

         with gr.Tab("Chat"):
-            infer_elems = create_infer_tab(top_elems)
+            engine.manager.all_elems["infer"] = create_infer_tab(engine)

         with gr.Tab("Export"):
-            export_elems = create_export_tab(top_elems)
+            engine.manager.all_elems["export"] = create_export_tab(engine)

-        elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
-        manager = Manager(elem_list)
-
-        demo.load(
-            manager.gen_label,
-            [top_elems["lang"]],
-            [elem for elems in elem_list for elem in elems.values()],
-        )
-
-        top_elems["lang"].change(
-            manager.gen_label,
-            [top_elems["lang"]],
-            [elem for elems in elem_list for elem in elems.values()],
-            queue=False
-        )
+        demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems())

     return demo


 def create_web_demo() -> gr.Blocks:
-    chat_model = WebChatModel(lazy_init=False)
+    engine = Engine(init_chat=True)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"], value="en")
+        lang = gr.Dropdown(choices=["en", "zh"])
+        config = gr.State(value=load_config())
+        lang.change(
+            engine.change_lang, [lang], engine.manager.list_elems(), queue=False
+        ).then(
+            save_config, inputs=[config, lang]
+        )

-        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
+        engine.manager.all_elems["top"] = dict(lang=lang)

-        manager = Manager([{"lang": lang}, chat_elems])
+        _, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)

-        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
-
-        lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
+        demo.load(engine.resume, [config], engine.manager.list_elems())

     return demo
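A typical entry point for the two factories above (the module path is assumed from this diff's import style; queue() is needed so the generator callbacks can stream):

from llmtuner.webui.interface import create_ui

create_ui().queue().launch(server_name="0.0.0.0", share=False)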
@@ -1,4 +1,8 @@
 LOCALES = {
+    "config": {
+        "en": {},
+        "zh": {}
+    },
     "lang": {
         "en": {
             "label": "Lang"
@@ -443,6 +447,10 @@ LOCALES = {
             "label": "保存预测结果"
         }
     },
+    "chat_box": {
+        "en": {},
+        "zh": {}
+    },
     "load_btn": {
         "en": {
             "value": "Load model"
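Each LOCALES entry maps an element name to per-language gr.update() kwargs, which is exactly what Engine.change_lang splats back onto the components; the empty dicts added for "config" and "chat_box" mean those elements carry no user-visible text to relabel:

import gradio as gr

gr.update(**LOCALES["load_btn"]["en"])  # -> relabels the button to "Load model"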
@@ -1,46 +1,36 @@
-import gradio as gr
-from gradio.components import Component
-from typing import Any, Dict, List
+from typing import TYPE_CHECKING, Dict, List

-from llmtuner.webui.common import get_model_path, list_dataset, load_config
-from llmtuner.webui.locales import LOCALES
-from llmtuner.webui.utils import get_time
+if TYPE_CHECKING:
+    from gradio.components import Component


 class Manager:

-    def __init__(self, elem_list: List[Dict[str, Component]]):
-        self.elem_list = elem_list
+    def __init__(self) -> None:
+        self.all_elems: Dict[str, Dict[str, "Component"]] = {}

-    def gen_refresh(self, lang: str) -> Dict[str, Any]:
-        refresh_dict = {
-            "dataset": {"choices": list_dataset()["choices"]},
-            "output_dir": {"value": get_time()}
-        }
-
-        user_config = load_config()
-        if not lang:
-            if user_config.get("lang", None):
-                lang = user_config["lang"]
-            else:
-                lang = "en"
-
-        refresh_dict["lang"] = {"value": lang}
-
-        if user_config.get("last_model", None):
-            refresh_dict["model_name"] = {"value": user_config["last_model"]}
-            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
-
-        return refresh_dict
-
-    def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]:  # cannot use TYPE_CHECKING
-        update_dict = {}
-        refresh_dict = self.gen_refresh(lang)
-
-        for elems in self.elem_list:
-            for name, component in elems.items():
-                update_dict[component] = gr.update(
-                    **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
-                )
-
-        return update_dict
+    def get_elem(self, name: str) -> "Component":
+        r"""
+        Example: top.lang, train.dataset
+        """
+        tab_name, elem_name = name.split(".")
+        return self.all_elems[tab_name][elem_name]
+
+    def get_base_elems(self):
+        return {
+            self.all_elems["top"]["config"],
+            self.all_elems["top"]["lang"],
+            self.all_elems["top"]["model_name"],
+            self.all_elems["top"]["model_path"],
+            self.all_elems["top"]["checkpoints"],
+            self.all_elems["top"]["finetuning_type"],
+            self.all_elems["top"]["quantization_bit"],
+            self.all_elems["top"]["template"],
+            self.all_elems["top"]["system_prompt"],
+            self.all_elems["top"]["flash_attn"],
+            self.all_elems["top"]["shift_attn"],
+            self.all_elems["top"]["rope_scaling"]
+        }
+
+    def list_elems(self) -> List["Component"]:
+        return [elem for elems in self.all_elems.values() for elem in elems.values()]
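get_elem is a plain two-level lookup over all_elems; a minimal sketch (the dropdown is a stand-in for any registered component):

import gradio as gr

manager = Manager()
lang = gr.Dropdown(choices=["en", "zh"])
manager.all_elems["top"] = {"lang": lang}
assert manager.get_elem("top.lang") is lang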
@@ -1,26 +1,32 @@
 import os
 import time
 import logging
-import threading
 import gradio as gr
-from typing import Any, Dict, Generator, List, Tuple
+from threading import Thread
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple

 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME

 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
+from llmtuner.extras.constants import TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_model_path, get_save_dir, load_config
+from llmtuner.webui.common import get_module, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

+if TYPE_CHECKING:
+    from llmtuner.webui.manager import Manager
+

 class Runner:

-    def __init__(self):
+    def __init__(self, manager: "Manager") -> None:
+        self.manager = manager
+        self.thread: "Thread" = None
         self.aborted = False
         self.running = False
         self.logger_handler = LoggerHandler()
@@ -28,20 +34,22 @@ class Runner:
         logging.root.addHandler(self.logger_handler)
         transformers.logging.add_handler(self.logger_handler)

-    def set_abort(self):
+    @property
+    def alive(self) -> bool:
+        return self.thread is not None
+
+    def set_abort(self) -> None:
         self.aborted = True
         self.running = False

-    def _initialize(
-        self, lang: str, model_name: str, dataset: List[str]
-    ) -> str:
+    def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
         if self.running:
             return ALERTS["err_conflict"][lang]

         if not model_name:
             return ALERTS["err_no_model"][lang]

-        if not get_model_path(model_name):
+        if not model_path:
             return ALERTS["err_no_path"][lang]

         if len(dataset) == 0:
@@ -52,9 +60,8 @@ class Runner:
         self.trainer_callback = LogCallback(self)
         return ""

-    def _finalize(
-        self, lang: str, finish_info: str
-    ) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> str:
+        self.thread = None
         self.running = False
         torch_gc()
         if self.aborted:
@@ -62,236 +69,171 @@ class Runner:
         else:
             return finish_info

-    def _parse_train_args(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        system_prompt: str,
-        flash_attn: bool,
-        shift_attn: bool,
-        rope_scaling: str,
-        training_stage: str,
-        dataset_dir: str,
-        dataset: List[str],
-        cutoff_len: int,
-        learning_rate: str,
-        num_train_epochs: str,
-        max_samples: str,
-        compute_type: str,
-        batch_size: int,
-        gradient_accumulation_steps: int,
-        lr_scheduler_type: str,
-        max_grad_norm: str,
-        val_size: float,
-        logging_steps: int,
-        save_steps: int,
-        warmup_steps: int,
-        lora_rank: int,
-        lora_dropout: float,
-        lora_target: str,
-        resume_lora_training: bool,
-        dpo_beta: float,
-        reward_model: str,
-        output_dir: str
-    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
-        if checkpoints:
-            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
+    def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
+        get = lambda name: data[self.manager.get_elem(name)]
+
+        if get("top.checkpoints"):
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None

-        output_dir = get_save_dir(model_name, finetuning_type, output_dir)
-
-        user_config = load_config()
-        cache_dir = user_config.get("cache_dir", None)
+        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))

         args = dict(
-            stage=TRAINING_STAGES[training_stage],
-            model_name_or_path=get_model_path(model_name),
+            stage=TRAINING_STAGES[get("train.training_stage")],
+            model_name_or_path=get("top.model_path"),
             do_train=True,
             overwrite_cache=False,
-            cache_dir=cache_dir,
+            cache_dir=get("top.config").get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
-            finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
-            template=template,
-            system_prompt=system_prompt,
-            flash_attn=flash_attn,
-            shift_attn=shift_attn,
-            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
-            dataset_dir=dataset_dir,
-            dataset=",".join(dataset),
-            cutoff_len=cutoff_len,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            max_samples=int(max_samples),
-            per_device_train_batch_size=batch_size,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            lr_scheduler_type=lr_scheduler_type,
-            max_grad_norm=float(max_grad_norm),
-            logging_steps=logging_steps,
-            save_steps=save_steps,
-            warmup_steps=warmup_steps,
-            lora_rank=lora_rank,
-            lora_dropout=lora_dropout,
-            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
-            resume_lora_training=resume_lora_training,
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            system_prompt=get("top.system_prompt"),
+            flash_attn=get("top.flash_attn"),
+            shift_attn=get("top.shift_attn"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            dataset_dir=get("train.dataset_dir"),
+            dataset=",".join(get("train.dataset")),
+            cutoff_len=get("train.cutoff_len"),
+            learning_rate=float(get("train.learning_rate")),
+            num_train_epochs=float(get("train.num_train_epochs")),
+            max_samples=int(get("train.max_samples")),
+            per_device_train_batch_size=get("train.batch_size"),
+            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
+            lr_scheduler_type=get("train.lr_scheduler_type"),
+            max_grad_norm=float(get("train.max_grad_norm")),
+            logging_steps=get("train.logging_steps"),
+            save_steps=get("train.save_steps"),
+            warmup_steps=get("train.warmup_steps"),
+            lora_rank=get("train.lora_rank"),
+            lora_dropout=get("train.lora_dropout"),
+            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
+            resume_lora_training=get("train.resume_lora_training"),
             output_dir=output_dir
         )
-        args[compute_type] = True
+        args[get("train.compute_type")] = True

-        if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
-            args["resume_lora_training"] = False
+        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
+            args["resume_lora_training"] = (args["quantization_bit"] is not None)

         if args["quantization_bit"] is not None:
             args["upcast_layernorm"] = True

         if args["stage"] == "ppo":
-            args["reward_model"] = reward_model
-            val_size = 0
+            args["reward_model"] = get("train.reward_model")

         if args["stage"] == "dpo":
-            args["dpo_beta"] = dpo_beta
+            args["dpo_beta"] = get("train.dpo_beta")

-        if val_size > 1e-6:
-            args["val_size"] = val_size
+        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+            args["val_size"] = get("train.val_size")
             args["evaluation_strategy"] = "steps"
-            args["eval_steps"] = save_steps
+            args["eval_steps"] = get("train.save_steps")
             args["load_best_model_at_end"] = True

-        return lang, model_name, dataset, output_dir, args
+        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args

-    def _parse_eval_args(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        system_prompt: str,
-        flash_attn: bool,
-        shift_attn: bool,
-        rope_scaling: str,
-        dataset_dir: str,
-        dataset: List[str],
-        cutoff_len: int,
-        max_samples: str,
-        batch_size: int,
-        predict: bool,
-        max_new_tokens: int,
-        top_p: float,
-        temperature: float
-    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
-        if checkpoints:
-            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
-            output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
+    def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
+        get = lambda name: data[self.manager.get_elem(name)]
+
+        if get("top.checkpoints"):
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
+            output_dir = get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
+            )
         else:
             checkpoint_dir = None
-            output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
-
-        user_config = load_config()
-        cache_dir = user_config.get("cache_dir", None)
+            output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")

         args = dict(
             stage="sft",
-            model_name_or_path=get_model_path(model_name),
+            model_name_or_path=get("top.model_path"),
             do_eval=True,
             overwrite_cache=False,
             predict_with_generate=True,
-            cache_dir=cache_dir,
+            cache_dir=get("top.config").get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
-            finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
-            template=template,
-            system_prompt=system_prompt,
-            flash_attn=flash_attn,
-            shift_attn=shift_attn,
-            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
-            dataset_dir=dataset_dir,
-            dataset=",".join(dataset),
-            cutoff_len=cutoff_len,
-            max_samples=int(max_samples),
-            per_device_eval_batch_size=batch_size,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
-            output_dir=output_dir
+            finetuning_type=get("top.finetuning_type"),
+            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            template=get("top.template"),
+            system_prompt=get("top.system_prompt"),
+            flash_attn=get("top.flash_attn"),
+            shift_attn=get("top.shift_attn"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            dataset_dir=get("eval.dataset_dir"),
+            dataset=",".join(get("eval.dataset")),
+            cutoff_len=get("eval.cutoff_len"),
+            max_samples=int(get("eval.max_samples")),
+            per_device_eval_batch_size=get("eval.batch_size"),
+            max_new_tokens=get("eval.max_new_tokens"),
+            top_p=get("eval.top_p"),
+            temperature=get("eval.temperature"),
+            output_dir=get("eval.output_dir")
         )

-        if predict:
+        if get("eval.predict"):
             args.pop("do_eval", None)
             args["do_predict"] = True

-        return lang, model_name, dataset, output_dir, args
+        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args

-    def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, dataset, _, args = self._parse_train_args(*args)
-        error = self._initialize(lang, model_name, dataset)
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, model_path, dataset, _, args = self._parse_train_args(data)
+        error = self._initialize(lang, model_name, model_path, dataset)
         if error:
             yield error, gr.update(visible=False)
         else:
             yield gen_cmd(args), gr.update(visible=False)

-    def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, dataset, _, args = self._parse_eval_args(*args)
-        error = self._initialize(lang, model_name, dataset)
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data)
+        error = self._initialize(lang, model_name, model_path, dataset)
         if error:
             yield error, gr.update(visible=False)
         else:
             yield gen_cmd(args), gr.update(visible=False)

-    def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
-        error = self._initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
-        self.running = True
-        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
-            finish_info = ALERTS["info_finished"][lang]
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self._finalize(lang, finish_info), gr.update(visible=False)
-
-    def run_eval(self, *args) -> Generator[str, None, None]:
-        lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
-        error = self._initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
-        self.running = True
-        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, "all_results.json")):
-            finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self._finalize(lang, finish_info), gr.update(visible=False)
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self.prepare(data, is_training=True)
+
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self.prepare(data, is_training=False)
+
+    def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        parse_func = self._parse_train_args if is_training else self._parse_eval_args
+        lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
+        error = self._initialize(lang, model_name, model_path, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+        self.thread = Thread(target=run_exp, kwargs=run_kwargs)
+        self.thread.start()
+        yield from self.monitor(lang, output_dir, is_training)
+
+    def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        while self.thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+            else:
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+        if is_training:
+            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+                finish_info = ALERTS["info_finished"][lang]
+            else:
+                finish_info = ALERTS["err_failed"][lang]
+        else:
+            if os.path.exists(os.path.join(output_dir, "all_results.json")):
+                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
+            else:
+                finish_info = ALERTS["err_failed"][lang]

         yield self._finalize(lang, finish_info), gr.update(visible=False)
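The prepare/monitor split above follows a launch-then-poll pattern: the training process runs on a Thread while the generator polls it and streams log text back to the UI. A self-contained sketch of just that pattern (the job and status strings are stand-ins):

import threading
import time

def run_in_background(job, poll_interval=2.0):
    thread = threading.Thread(target=job)
    thread.start()
    while thread.is_alive():
        time.sleep(poll_interval)
        yield "running..."  # stands in for self.logger_handler.log
    yield "finished"

for status in run_in_background(lambda: time.sleep(5)):
    print(status)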
@@ -8,7 +8,7 @@ from datetime import datetime

 from llmtuner.extras.ploting import smooth
 from llmtuner.tuner import export_model
-from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
+from llmtuner.webui.common import get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

 if TYPE_CHECKING:
@@ -119,6 +119,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
 def save_model(
     lang: str,
     model_name: str,
+    model_path: str,
     checkpoints: List[str],
     finetuning_type: str,
     template: str,
@@ -129,8 +130,7 @@ def save_model(
         yield ALERTS["err_no_model"][lang]
         return

-    model_name_or_path = get_model_path(model_name)
-    if not model_name_or_path:
+    if not model_path:
         yield ALERTS["err_no_path"][lang]
         return

@@ -138,17 +138,13 @@ def save_model(
         yield ALERTS["err_no_checkpoint"][lang]
         return

-    checkpoint_dir = ",".join(
-        [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
-    )
-
     if not save_dir:
         yield ALERTS["err_no_save_dir"][lang]
         return

     args = dict(
-        model_name_or_path=model_name_or_path,
-        checkpoint_dir=checkpoint_dir,
+        model_name_or_path=model_path,
+        checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
         finetuning_type=finetuning_type,
         template=template,
         output_dir=save_dir