support save args in webui #2807 #3046

some ideas are borrowed from @marko1616
This commit is contained in:
hiyouga 2024-03-30 23:09:12 +08:00
parent 257f643a74
commit 7a086ed333
9 changed files with 219 additions and 80 deletions

View File

@ -66,6 +66,7 @@ def check_dependencies() -> None:
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

View File

@ -20,6 +20,7 @@ from ..extras.misc import use_modelscope
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
@ -33,6 +34,10 @@ def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def get_save_path(config_path: str) -> os.PathLike:
    r"""
    Builds the location of a saved-arguments file by joining the config directory
    (DEFAULT_CONFIG_DIR) with the given config file name.
    """
    save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
    return save_path
def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
@ -52,6 +57,22 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
json.dump(user_config, f, indent=2, ensure_ascii=False)
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    r"""
    Loads previously saved arguments from the JSON file at get_save_path(config_path).

    Returns the parsed JSON content, or None when anything goes wrong while resolving,
    opening or parsing the file (best-effort: every Exception is mapped to None).
    """
    try:
        with open(get_save_path(config_path), "r", encoding="utf-8") as reader:
            saved_args = json.load(reader)
    except Exception:
        # Callers treat None as "config not available", so no error is propagated.
        saved_args = None

    return saved_args
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    r"""
    Saves the arguments as an indented UTF-8 JSON file at get_save_path(config_path).

    Args:
        config_path: file name (optionally with sub-directories) of the config file.
        config_dict: JSON-serializable mapping to dump.

    Returns:
        The path of the written file as a string.
    """
    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
    save_path = get_save_path(config_path)  # resolve once, reuse for writing and reporting

    # config_path may contain sub-directories (e.g. "exp1/run.json"); creating only
    # DEFAULT_CONFIG_DIR would make open() fail with FileNotFoundError in that case.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    return str(save_path)
def get_model_path(model_name: str) -> str:
user_config = load_config()
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))

View File

@ -46,8 +46,8 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
stop_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)

View File

@ -27,8 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
@ -127,19 +125,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
create_new_adapter = gr.Checkbox()
with gr.Row():
use_rslora = gr.Checkbox(scale=1)
use_dora = gr.Checkbox(scale=1)
create_new_adapter = gr.Checkbox(scale=1)
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
input_elems.update(
{lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
{
lora_rank,
lora_alpha,
lora_dropout,
loraplus_lr_ratio,
create_new_adapter,
use_rslora,
use_dora,
lora_target,
additional_target,
}
)
elem_dict.update(
dict(
@ -147,10 +156,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_target=lora_target,
loraplus_lr_ratio=loraplus_lr_ratio,
create_new_adapter=create_new_adapter,
use_rslora=use_rslora,
use_dora=use_dora,
create_new_adapter=create_new_adapter,
lora_target=lora_target,
additional_target=additional_target,
)
)
@ -161,13 +171,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
@ -177,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
galore_target = gr.Textbox(value="mlp,attn", scale=3)
galore_target = gr.Textbox(value="all", scale=3)
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
elem_dict.update(
@ -193,13 +196,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
stop_btn = gr.Button()
arg_save_btn = gr.Button()
arg_load_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
with gr.Column(scale=3):
with gr.Row():
output_dir = gr.Textbox()
config_path = gr.Textbox()
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
@ -211,20 +217,38 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_elems.add(output_dir)
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems),
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
arg_save_btn=arg_save_btn,
arg_load_btn=arg_load_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
config_path=config_path,
resume_btn=resume_btn,
process_bar=process_bar,
output_box=output_box,

View File

@ -38,8 +38,9 @@ class Engine:
if not self.pure_chat:
init_dict["train.dataset"] = {"choices": list_dataset().choices}
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
init_dict["train.output_dir"] = {"value": "train_" + get_time()}
init_dict["eval.output_dir"] = {"value": "eval_" + get_time()}
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}

View File

@ -1,5 +1,4 @@
import gradio as gr
from transformers.utils.versions import require_version
from .common import save_config
from .components import (
@ -14,9 +13,6 @@ from .css import CSS
from .engine import Engine
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
def create_ui(demo_mode: bool = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
@ -29,21 +25,21 @@ def create_ui(demo_mode: bool = False) -> gr.Blocks:
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.add_elem_dict("top", create_top())
engine.manager.add_elems("top", create_top())
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
with gr.Tab("Train"):
engine.manager.add_elem_dict("train", create_train_tab(engine))
engine.manager.add_elems("train", create_train_tab(engine))
with gr.Tab("Evaluate & Predict"):
engine.manager.add_elem_dict("eval", create_eval_tab(engine))
engine.manager.add_elems("eval", create_eval_tab(engine))
with gr.Tab("Chat"):
engine.manager.add_elem_dict("infer", create_infer_tab(engine))
engine.manager.add_elems("infer", create_infer_tab(engine))
if not demo_mode:
with gr.Tab("Export"):
engine.manager.add_elem_dict("export", create_export_tab(engine))
engine.manager.add_elems("export", create_export_tab(engine))
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
@ -57,10 +53,10 @@ def create_web_demo() -> gr.Blocks:
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.add_elem_dict("top", dict(lang=lang))
engine.manager.add_elems("top", dict(lang=lang))
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elem_dict("infer", dict(chat_box=chat_box, **chat_elems))
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)

View File

@ -628,18 +628,32 @@ LOCALES = {
"info": "LoRA 权重随机丢弃的概率。",
},
},
"lora_target": {
"loraplus_lr_ratio": {
"en": {
"label": "LoRA modules (optional)",
"info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
"label": "LoRA+ LR ratio",
"info": "The LR ratio of the B matrices in LoRA.",
},
"ru": {
"label": "Модули LoRA (опционально)",
"info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
"label": "LoRA+ LR коэффициент",
"info": "Коэффициент LR матриц B в LoRA.",
},
"zh": {
"label": "LoRA 作用模块(非必填)",
"info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
"label": "LoRA+ 学习率比例",
"info": "LoRA+ 中 B 矩阵的学习率倍数。",
},
},
"create_new_adapter": {
"en": {
"label": "Create new adapter",
"info": "Create a new adapter with randomly initialized weight upon the existing one.",
},
"ru": {
"label": "Создать новый адаптер",
"info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
},
"zh": {
"label": "新建适配器",
"info": "在现有的适配器上创建一个随机初始化后的新适配器。",
},
},
"use_rslora": {
@ -670,18 +684,18 @@ LOCALES = {
"info": "使用权重分解的 LoRA。",
},
},
"create_new_adapter": {
"lora_target": {
"en": {
"label": "Create new adapter",
"info": "Create a new adapter with randomly initialized weight upon the existing one.",
"label": "LoRA modules (optional)",
"info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
},
"ru": {
"label": "Создать новый адаптер",
"info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
"label": "Модули LoRA (опционально)",
"info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
},
"zh": {
"label": "新建适配器",
"info": "在现有的适配器上创建一个随机初始化后的新适配器",
"label": "LoRA 作用模块(非必填)",
"info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称",
},
},
"additional_target": {
@ -849,6 +863,28 @@ LOCALES = {
"value": "预览命令",
},
},
"arg_save_btn": {
"en": {
"value": "Save arguments",
},
"ru": {
"value": "Сохранить аргументы",
},
"zh": {
"value": "保存训练参数",
},
},
"arg_load_btn": {
"en": {
"value": "Load arguments",
},
"ru": {
"value": "Загрузить аргументы",
},
"zh": {
"value": "载入训练参数",
},
},
"start_btn": {
"en": {
"value": "Start",
@ -885,6 +921,20 @@ LOCALES = {
"info": "保存结果的路径。",
},
},
"config_path": {
"en": {
"label": "Config path",
"info": "Path to config saving arguments.",
},
"ru": {
"label": "Путь к конфигурации",
"info": "Путь для сохранения аргументов конфигурации.",
},
"zh": {
"label": "配置路径",
"info": "保存训练参数的配置文件路径。",
},
},
"output_box": {
"en": {
"value": "Ready.",
@ -1236,6 +1286,11 @@ ALERTS = {
"ru": "Неверная схема JSON.",
"zh": "Json 格式错误。",
},
"err_config_not_found": {
"en": "Config file is not found.",
"ru": "Файл конфигурации не найден.",
"zh": "未找到配置文件。",
},
"warn_no_cuda": {
"en": "CUDA environment was not detected.",
"ru": "Среда CUDA не обнаружена.",
@ -1256,6 +1311,11 @@ ALERTS = {
"ru": "Завершено.",
"zh": "训练完毕。",
},
"info_config_saved": {
"en": "Arguments have been saved at: ",
"ru": "Аргументы были сохранены по адресу: ",
"zh": "训练参数已保存至:",
},
"info_loading": {
"en": "Loading model...",
"ru": "Загрузка модели...",

View File

@ -7,27 +7,30 @@ if TYPE_CHECKING:
class Manager:
def __init__(self) -> None:
self._elem_dicts: Dict[str, Dict[str, "Component"]] = {}
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}
def add_elem_dict(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
r"""
Adds a elem dict.
Adds elements to manager.
"""
self._elem_dicts[tab_name] = elem_dict
for elem_name, elem in elem_dict.items():
elem_id = "{}.{}".format(tab_name, elem_name)
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id
def get_elem_list(self) -> List["Component"]:
r"""
Returns the list of all elements.
"""
return [elem for elem_dict in self._elem_dicts.values() for elem in elem_dict.values()]
return list(self._id_to_elem.values())
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
r"""
Returns an iterator over all elements with their names.
"""
for elem_dict in self._elem_dicts.values():
for elem_name, elem in elem_dict.items():
yield elem_name, elem
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem
def get_elem_by_id(self, elem_id: str) -> "Component":
r"""
@ -35,21 +38,26 @@ class Manager:
Example: top.lang, train.dataset
"""
tab_name, elem_name = elem_id.split(".")
return self._elem_dicts[tab_name][elem_name]
return self._id_to_elem[elem_id]
def get_id_by_elem(self, elem: "Component") -> str:
    r"""
    Looks up the id registered for the given element in the element-to-id mapping.
    """
    elem_id = self._elem_to_id[elem]
    return elem_id
def get_base_elems(self) -> Set["Component"]:
r"""
Gets the base elements that are commonly used.
"""
return {
self._elem_dicts["top"]["lang"],
self._elem_dicts["top"]["model_name"],
self._elem_dicts["top"]["model_path"],
self._elem_dicts["top"]["finetuning_type"],
self._elem_dicts["top"]["adapter_path"],
self._elem_dicts["top"]["quantization_bit"],
self._elem_dicts["top"]["template"],
self._elem_dicts["top"]["rope_scaling"],
self._elem_dicts["top"]["booster"],
self._id_to_elem["top.lang"],
self._id_to_elem["top.model_name"],
self._id_to_elem["top.model_path"],
self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.adapter_path"],
self._id_to_elem["top.quantization_bit"],
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
}

View File

@ -15,7 +15,7 @@ from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp
from .common import get_module, get_save_dir, load_config
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
@ -150,23 +150,21 @@ class Runner:
args["disable_tqdm"] = True
if args["finetuning_type"] == "freeze":
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
args["num_layer_trainable"] = get("train.num_layer_trainable")
args["name_module_trainable"] = get("train.name_module_trainable")
elif args["finetuning_type"] == "lora":
args["lora_rank"] = int(get("train.lora_rank"))
args["lora_alpha"] = int(get("train.lora_alpha"))
args["lora_dropout"] = float(get("train.lora_dropout"))
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["lora_rank"] = get("train.lora_rank")
args["lora_alpha"] = get("train.lora_alpha")
args["lora_dropout"] = get("train.lora_dropout")
args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
args["create_new_adapter"] = get("train.create_new_adapter")
args["use_rslora"] = get("train.use_rslora")
args["use_dora"] = get("train.use_dora")
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["additional_target"] = get("train.additional_target") or None
if args["stage"] in ["rm", "ppo", "dpo"]:
args["create_new_adapter"] = args["quantization_bit"] is None
else:
args["create_new_adapter"] = get("train.create_new_adapter")
if args["use_llama_pro"]:
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
args["num_layer_trainable"] = get("train.num_layer_trainable")
if args["stage"] == "ppo":
args["reward_model"] = ",".join(
@ -305,3 +303,33 @@ class Runner:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.Slider(visible=False)
def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
    r"""
    Saves the current input values to the config file named by `train.config_path`.

    The inputs are first checked with `self._initialize`; if it reports an error, a
    warning is raised in the UI and the error text is returned. Otherwise every value
    in `data` is stored under its element id, except the ids that are skipped below.

    Returns the message for the output box together with a hidden slider.
    """
    error = self._initialize(data, do_train=True, from_preview=True)
    if error:
        gr.Warning(error)
        return error, gr.Slider(visible=False)

    lang = data[self.manager.get_elem_by_id("top.lang")]
    config_path = data[self.manager.get_elem_by_id("train.config_path")]

    # These element ids are left out of the saved config.
    excluded_ids = ("top.lang", "top.model_path", "train.output_dir", "train.config_path")
    id_value_pairs = ((self.manager.get_id_by_elem(elem), value) for elem, value in data.items())
    config_dict: Dict[str, Any] = {
        elem_id: value for elem_id, value in id_value_pairs if elem_id not in excluded_ids
    }

    save_path = save_args(config_path, config_dict)
    message = ALERTS["info_config_saved"][lang] + save_path
    return message, gr.Slider(visible=False)
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
    r"""
    Loads saved arguments and maps them back onto their UI elements.

    If the config cannot be loaded, a localized warning is raised and only the language
    element is returned with its current value. Otherwise each saved element id is
    resolved through the manager and paired with its saved value.
    """
    config_dict = load_args(config_path)
    if config_dict is None:
        gr.Warning(ALERTS["err_config_not_found"][lang])
        return {self.manager.get_elem_by_id("top.lang"): lang}

    return {self.manager.get_elem_by_id(elem_id): value for elem_id, value in config_dict.items()}