some ideas are borrowed from @marko1616
This commit is contained in:
parent
257f643a74
commit
7a086ed333
|
@ -66,6 +66,7 @@ def check_dependencies() -> None:
|
|||
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
|
||||
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
|
||||
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
|
|
|
@ -20,6 +20,7 @@ from ..extras.misc import use_modelscope
|
|||
|
||||
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_CONFIG_DIR = "config"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
|
@ -33,6 +34,10 @@ def get_config_path() -> os.PathLike:
|
|||
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
||||
|
||||
|
||||
def get_save_path(config_path: str) -> os.PathLike:
|
||||
return os.path.join(DEFAULT_CONFIG_DIR, config_path)
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
|
@ -52,6 +57,22 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
|||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
|
||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
|
||||
json.dump(config_dict, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return str(get_save_path(config_path))
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> str:
|
||||
user_config = load_config()
|
||||
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||
|
|
|
@ -46,8 +46,8 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
start_btn = gr.Button(variant="primary")
|
||||
stop_btn = gr.Button(variant="stop")
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
|
|
|
@ -27,8 +27,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
|
@ -127,19 +125,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
use_rslora = gr.Checkbox(scale=1)
|
||||
use_dora = gr.Checkbox(scale=1)
|
||||
create_new_adapter = gr.Checkbox(scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
additional_target = gr.Textbox(scale=2)
|
||||
|
||||
input_elems.update(
|
||||
{lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
|
||||
{
|
||||
lora_rank,
|
||||
lora_alpha,
|
||||
lora_dropout,
|
||||
loraplus_lr_ratio,
|
||||
create_new_adapter,
|
||||
use_rslora,
|
||||
use_dora,
|
||||
lora_target,
|
||||
additional_target,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
|
@ -147,10 +156,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
lora_rank=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
create_new_adapter=create_new_adapter,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
create_new_adapter=create_new_adapter,
|
||||
lora_target=lora_target,
|
||||
additional_target=additional_target,
|
||||
)
|
||||
)
|
||||
|
@ -161,13 +171,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
||||
|
||||
|
@ -177,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
|
||||
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
|
||||
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
|
||||
galore_target = gr.Textbox(value="mlp,attn", scale=3)
|
||||
galore_target = gr.Textbox(value="all", scale=3)
|
||||
|
||||
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
|
||||
elem_dict.update(
|
||||
|
@ -193,13 +196,16 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
arg_save_btn = gr.Button()
|
||||
arg_load_btn = gr.Button()
|
||||
start_btn = gr.Button(variant="primary")
|
||||
stop_btn = gr.Button(variant="stop")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox()
|
||||
config_path = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
|
@ -211,20 +217,38 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_elems.add(output_dir)
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, process_bar]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_load_btn.click(
|
||||
engine.runner.load_args,
|
||||
[engine.manager.get_elem_by_id("top.lang"), config_path],
|
||||
list(input_elems),
|
||||
concurrency_limit=None,
|
||||
)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
arg_save_btn=arg_save_btn,
|
||||
arg_load_btn=arg_load_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
output_box=output_box,
|
||||
|
|
|
@ -38,8 +38,9 @@ class Engine:
|
|||
if not self.pure_chat:
|
||||
init_dict["train.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["train.output_dir"] = {"value": "train_" + get_time()}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_" + get_time()}
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import gradio as gr
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from .common import save_config
|
||||
from .components import (
|
||||
|
@ -14,9 +13,6 @@ from .css import CSS
|
|||
from .engine import Engine
|
||||
|
||||
|
||||
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
|
||||
|
||||
|
||||
def create_ui(demo_mode: bool = False) -> gr.Blocks:
|
||||
engine = Engine(demo_mode=demo_mode, pure_chat=False)
|
||||
|
||||
|
@ -29,21 +25,21 @@ def create_ui(demo_mode: bool = False) -> gr.Blocks:
|
|||
)
|
||||
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
|
||||
|
||||
engine.manager.add_elem_dict("top", create_top())
|
||||
engine.manager.add_elems("top", create_top())
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
|
||||
|
||||
with gr.Tab("Train"):
|
||||
engine.manager.add_elem_dict("train", create_train_tab(engine))
|
||||
engine.manager.add_elems("train", create_train_tab(engine))
|
||||
|
||||
with gr.Tab("Evaluate & Predict"):
|
||||
engine.manager.add_elem_dict("eval", create_eval_tab(engine))
|
||||
engine.manager.add_elems("eval", create_eval_tab(engine))
|
||||
|
||||
with gr.Tab("Chat"):
|
||||
engine.manager.add_elem_dict("infer", create_infer_tab(engine))
|
||||
engine.manager.add_elems("infer", create_infer_tab(engine))
|
||||
|
||||
if not demo_mode:
|
||||
with gr.Tab("Export"):
|
||||
engine.manager.add_elem_dict("export", create_export_tab(engine))
|
||||
engine.manager.add_elems("export", create_export_tab(engine))
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||
|
@ -57,10 +53,10 @@ def create_web_demo() -> gr.Blocks:
|
|||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
engine.manager.add_elem_dict("top", dict(lang=lang))
|
||||
engine.manager.add_elems("top", dict(lang=lang))
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.add_elem_dict("infer", dict(chat_box=chat_box, **chat_elems))
|
||||
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||
|
|
|
@ -628,18 +628,32 @@ LOCALES = {
|
|||
"info": "LoRA 权重随机丢弃的概率。",
|
||||
},
|
||||
},
|
||||
"lora_target": {
|
||||
"loraplus_lr_ratio": {
|
||||
"en": {
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
|
||||
"label": "LoRA+ LR ratio",
|
||||
"info": "The LR ratio of the B matrices in LoRA.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Модули LoRA (опционально)",
|
||||
"info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
|
||||
"label": "LoRA+ LR коэффициент",
|
||||
"info": "Коэффициент LR матриц B в LoRA.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 作用模块(非必填)",
|
||||
"info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
|
||||
"label": "LoRA+ 学习率比例",
|
||||
"info": "LoRA+ 中 B 矩阵的学习率倍数。",
|
||||
},
|
||||
},
|
||||
"create_new_adapter": {
|
||||
"en": {
|
||||
"label": "Create new adapter",
|
||||
"info": "Create a new adapter with randomly initialized weight upon the existing one.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Создать новый адаптер",
|
||||
"info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "新建适配器",
|
||||
"info": "在现有的适配器上创建一个随机初始化后的新适配器。",
|
||||
},
|
||||
},
|
||||
"use_rslora": {
|
||||
|
@ -670,18 +684,18 @@ LOCALES = {
|
|||
"info": "使用权重分解的 LoRA。",
|
||||
},
|
||||
},
|
||||
"create_new_adapter": {
|
||||
"lora_target": {
|
||||
"en": {
|
||||
"label": "Create new adapter",
|
||||
"info": "Create a new adapter with randomly initialized weight upon the existing one.",
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Создать новый адаптер",
|
||||
"info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
|
||||
"label": "Модули LoRA (опционально)",
|
||||
"info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "新建适配器",
|
||||
"info": "在现有的适配器上创建一个随机初始化后的新适配器。",
|
||||
"label": "LoRA 作用模块(非必填)",
|
||||
"info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"additional_target": {
|
||||
|
@ -849,6 +863,28 @@ LOCALES = {
|
|||
"value": "预览命令",
|
||||
},
|
||||
},
|
||||
"arg_save_btn": {
|
||||
"en": {
|
||||
"value": "Save arguments",
|
||||
},
|
||||
"ru": {
|
||||
"value": "Сохранить аргументы",
|
||||
},
|
||||
"zh": {
|
||||
"value": "保存训练参数",
|
||||
},
|
||||
},
|
||||
"arg_load_btn": {
|
||||
"en": {
|
||||
"value": "Load arguments",
|
||||
},
|
||||
"ru": {
|
||||
"value": "Загрузить аргументы",
|
||||
},
|
||||
"zh": {
|
||||
"value": "载入训练参数",
|
||||
},
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start",
|
||||
|
@ -885,6 +921,20 @@ LOCALES = {
|
|||
"info": "保存结果的路径。",
|
||||
},
|
||||
},
|
||||
"config_path": {
|
||||
"en": {
|
||||
"label": "Config path",
|
||||
"info": "Path to config saving arguments.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Путь к конфигурации",
|
||||
"info": "Путь для сохранения аргументов конфигурации.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "配置路径",
|
||||
"info": "保存训练参数的配置文件路径。",
|
||||
},
|
||||
},
|
||||
"output_box": {
|
||||
"en": {
|
||||
"value": "Ready.",
|
||||
|
@ -1236,6 +1286,11 @@ ALERTS = {
|
|||
"ru": "Неверная схема JSON.",
|
||||
"zh": "Json 格式错误。",
|
||||
},
|
||||
"err_config_not_found": {
|
||||
"en": "Config file is not found.",
|
||||
"ru": "Файл конфигурации не найден.",
|
||||
"zh": "未找到配置文件。",
|
||||
},
|
||||
"warn_no_cuda": {
|
||||
"en": "CUDA environment was not detected.",
|
||||
"ru": "Среда CUDA не обнаружена.",
|
||||
|
@ -1256,6 +1311,11 @@ ALERTS = {
|
|||
"ru": "Завершено.",
|
||||
"zh": "训练完毕。",
|
||||
},
|
||||
"info_config_saved": {
|
||||
"en": "Arguments have been saved at: ",
|
||||
"ru": "Аргументы были сохранены по адресу: ",
|
||||
"zh": "训练参数已保存至:",
|
||||
},
|
||||
"info_loading": {
|
||||
"en": "Loading model...",
|
||||
"ru": "Загрузка модели...",
|
||||
|
|
|
@ -7,27 +7,30 @@ if TYPE_CHECKING:
|
|||
|
||||
class Manager:
|
||||
def __init__(self) -> None:
|
||||
self._elem_dicts: Dict[str, Dict[str, "Component"]] = {}
|
||||
self._id_to_elem: Dict[str, "Component"] = {}
|
||||
self._elem_to_id: Dict["Component", str] = {}
|
||||
|
||||
def add_elem_dict(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
|
||||
def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
|
||||
r"""
|
||||
Adds a elem dict.
|
||||
Adds elements to manager.
|
||||
"""
|
||||
self._elem_dicts[tab_name] = elem_dict
|
||||
for elem_name, elem in elem_dict.items():
|
||||
elem_id = "{}.{}".format(tab_name, elem_name)
|
||||
self._id_to_elem[elem_id] = elem
|
||||
self._elem_to_id[elem] = elem_id
|
||||
|
||||
def get_elem_list(self) -> List["Component"]:
|
||||
r"""
|
||||
Returns the list of all elements.
|
||||
"""
|
||||
return [elem for elem_dict in self._elem_dicts.values() for elem in elem_dict.values()]
|
||||
return list(self._id_to_elem.values())
|
||||
|
||||
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
|
||||
r"""
|
||||
Returns an iterator over all elements with their names.
|
||||
"""
|
||||
for elem_dict in self._elem_dicts.values():
|
||||
for elem_name, elem in elem_dict.items():
|
||||
yield elem_name, elem
|
||||
for elem_id, elem in self._id_to_elem.items():
|
||||
yield elem_id.split(".")[-1], elem
|
||||
|
||||
def get_elem_by_id(self, elem_id: str) -> "Component":
|
||||
r"""
|
||||
|
@ -35,21 +38,26 @@ class Manager:
|
|||
|
||||
Example: top.lang, train.dataset
|
||||
"""
|
||||
tab_name, elem_name = elem_id.split(".")
|
||||
return self._elem_dicts[tab_name][elem_name]
|
||||
return self._id_to_elem[elem_id]
|
||||
|
||||
def get_id_by_elem(self, elem: "Component") -> str:
|
||||
r"""
|
||||
Gets id by element.
|
||||
"""
|
||||
return self._elem_to_id[elem]
|
||||
|
||||
def get_base_elems(self) -> Set["Component"]:
|
||||
r"""
|
||||
Gets the base elements that are commonly used.
|
||||
"""
|
||||
return {
|
||||
self._elem_dicts["top"]["lang"],
|
||||
self._elem_dicts["top"]["model_name"],
|
||||
self._elem_dicts["top"]["model_path"],
|
||||
self._elem_dicts["top"]["finetuning_type"],
|
||||
self._elem_dicts["top"]["adapter_path"],
|
||||
self._elem_dicts["top"]["quantization_bit"],
|
||||
self._elem_dicts["top"]["template"],
|
||||
self._elem_dicts["top"]["rope_scaling"],
|
||||
self._elem_dicts["top"]["booster"],
|
||||
self._id_to_elem["top.lang"],
|
||||
self._id_to_elem["top.model_name"],
|
||||
self._id_to_elem["top.model_path"],
|
||||
self._id_to_elem["top.finetuning_type"],
|
||||
self._id_to_elem["top.adapter_path"],
|
||||
self._id_to_elem["top.quantization_bit"],
|
||||
self._id_to_elem["top.template"],
|
||||
self._id_to_elem["top.rope_scaling"],
|
||||
self._id_to_elem["top.booster"],
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ from ..extras.constants import TRAINING_STAGES
|
|||
from ..extras.logging import LoggerHandler
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..train import run_exp
|
||||
from .common import get_module, get_save_dir, load_config
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
|
@ -150,23 +150,21 @@ class Runner:
|
|||
args["disable_tqdm"] = True
|
||||
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||
args["name_module_trainable"] = get("train.name_module_trainable")
|
||||
elif args["finetuning_type"] == "lora":
|
||||
args["lora_rank"] = int(get("train.lora_rank"))
|
||||
args["lora_alpha"] = int(get("train.lora_alpha"))
|
||||
args["lora_dropout"] = float(get("train.lora_dropout"))
|
||||
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
|
||||
args["lora_rank"] = get("train.lora_rank")
|
||||
args["lora_alpha"] = get("train.lora_alpha")
|
||||
args["lora_dropout"] = get("train.lora_dropout")
|
||||
args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
|
||||
args["create_new_adapter"] = get("train.create_new_adapter")
|
||||
args["use_rslora"] = get("train.use_rslora")
|
||||
args["use_dora"] = get("train.use_dora")
|
||||
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
|
||||
args["additional_target"] = get("train.additional_target") or None
|
||||
if args["stage"] in ["rm", "ppo", "dpo"]:
|
||||
args["create_new_adapter"] = args["quantization_bit"] is None
|
||||
else:
|
||||
args["create_new_adapter"] = get("train.create_new_adapter")
|
||||
|
||||
if args["use_llama_pro"]:
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = ",".join(
|
||||
|
@ -305,3 +303,33 @@ class Runner:
|
|||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self._finalize(lang, finish_info), gr.Slider(visible=False)
|
||||
|
||||
def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
return error, gr.Slider(visible=False)
|
||||
|
||||
config_dict: Dict[str, Any] = {}
|
||||
lang = data[self.manager.get_elem_by_id("top.lang")]
|
||||
config_path = data[self.manager.get_elem_by_id("train.config_path")]
|
||||
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
|
||||
for elem, value in data.items():
|
||||
elem_id = self.manager.get_id_by_elem(elem)
|
||||
if elem_id not in skip_ids:
|
||||
config_dict[elem_id] = value
|
||||
|
||||
save_path = save_args(config_path, config_dict)
|
||||
return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
|
||||
|
||||
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
|
||||
config_dict = load_args(config_path)
|
||||
if config_dict is None:
|
||||
gr.Warning(ALERTS["err_config_not_found"][lang])
|
||||
return {self.manager.get_elem_by_id("top.lang"): lang}
|
||||
|
||||
output_dict: Dict["Component", Any] = {}
|
||||
for elem_id, value in config_dict.items():
|
||||
output_dict[self.manager.get_elem_by_id(elem_id)] = value
|
||||
|
||||
return output_dict
|
||||
|
|
Loading…
Reference in New Issue