update webui

hiyouga 2024-04-01 16:23:28 +08:00
parent 816d714146
commit d0842f6828
4 changed files with 36 additions and 15 deletions

View File

@@ -4,7 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
     --stage orpo \
     --do_train \
     --model_name_or_path meta-llama/Llama-2-7b-hf \
-    --dataset comparison_gpt4_en \
+    --dataset orca_rlhf \
     --dataset_dir ../../data \
     --template default \
     --finetuning_type lora \

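Note on the dataset change: ORPO, like DPO, consumes pairwise preference data, and the example script now points at orca_rlhf. As a rough illustration only (the field names below are assumptions for illustration, not read from the orca_rlhf files), one record of such a dataset typically carries a prompt plus a preferred and a rejected response:

# Hypothetical shape of a pairwise-preference record; keys are illustrative.
example = {
    "prompt": "Explain why the sky is blue.",
    "chosen": "Sunlight scatters off air molecules, and shorter blue wavelengths scatter the most.",
    "rejected": "The sky is blue because it reflects the ocean.",
}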
View File

@@ -21,10 +21,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         training_stage = gr.Dropdown(
-            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
+            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
         )
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
-        dataset = gr.Dropdown(multiselect=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
+        dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
         preview_elems = create_preview_box(dataset_dir, dataset)

     input_elems.update({training_stage, dataset_dir, dataset})
@@ -75,11 +75,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim = gr.Textbox(value="adamw_torch")

         with gr.Row():
-            resize_vocab = gr.Checkbox()
-            packing = gr.Checkbox()
-            upcast_layernorm = gr.Checkbox()
-            use_llama_pro = gr.Checkbox()
-            shift_attn = gr.Checkbox()
+            with gr.Column():
+                resize_vocab = gr.Checkbox()
+                packing = gr.Checkbox()
+            with gr.Column():
+                upcast_layernorm = gr.Checkbox()
+                use_llama_pro = gr.Checkbox()
+            with gr.Column():
+                shift_attn = gr.Checkbox()
+                report_to = gr.Checkbox()

     input_elems.update(
         {
@@ -93,6 +99,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             upcast_layernorm,
             use_llama_pro,
             shift_attn,
+            report_to,
         }
     )

     elem_dict.update(
@@ -108,6 +115,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
+            report_to=report_to,
         )
     )

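Note on the layout change: wrapping pairs of checkboxes in gr.Column() blocks inside the gr.Row() makes Gradio render them as a compact grid rather than five checkboxes stretched across one row, and makes room for the new report_to toggle. A minimal standalone sketch of the same Row/Column pattern (labels here are placeholders, not the webui's real components):

import gradio as gr

# Three columns inside one row: each column stacks two checkboxes,
# so the row renders as a 3x2 grid.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Checkbox(label="resize_vocab")
            gr.Checkbox(label="packing")
        with gr.Column():
            gr.Checkbox(label="upcast_layernorm")
            gr.Checkbox(label="use_llama_pro")
        with gr.Column():
            gr.Checkbox(label="shift_attn")
            gr.Checkbox(label="report_to")

demo.launch()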
View File

@@ -536,6 +536,20 @@ LOCALES = {
             "info": "使用 LongLoRA 提出的 shift short attention。",
         },
     },
+    "report_to": {
+        "en": {
+            "label": "Enable external logger",
+            "info": "Use TensorBoard or wandb to log experiment.",
+        },
+        "ru": {
+            "label": "Включить внешний регистратор",
+            "info": "Использовать TensorBoard или wandb для ведения журнала экспериментов.",
+        },
+        "zh": {
+            "label": "启用外部记录面板",
+            "info": "使用 TensorBoard 或 wandb 记录实验。",
+        },
+    },
     "freeze_tab": {
         "en": {
             "label": "Freeze tuning configurations",

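Note on the LOCALES entry: each key in LOCALES maps a webui component to its per-language "label" and "info" strings (English, Russian, and Chinese here), which the UI reads when building components for the selected language. A hedged sketch of that lookup (the helper below is illustrative, not part of this commit):

# Illustrative helper: fetch the localized label/info for one component.
def localize(locales: dict, key: str, lang: str) -> dict:
    entry = locales[key][lang]
    return {"label": entry["label"], "info": entry["info"]}

# localize(LOCALES, "report_to", "en")
# -> {"label": "Enable external logger",
#     "info": "Use TensorBoard or wandb to log experiment."}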
View File

@@ -80,20 +80,18 @@ class Runner:
         if not from_preview and not is_torch_cuda_available():
             gr.Warning(ALERTS["warn_no_cuda"][lang])

         self.aborted = False
         self.logger_handler.reset()
         self.trainer_callback = LogCallback(self)
         return ""

     def _finalize(self, lang: str, finish_info: str) -> str:
+        finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         self.thread = None
-        self.running_data = None
+        self.aborted = False
         self.running = False
+        self.running_data = None
         torch_gc()
-        if self.aborted:
-            return ALERTS["info_aborted"][lang]
-        else:
-            return finish_info
+        return finish_info

     def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@@ -141,6 +139,7 @@ class Runner:
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),
+            report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
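Note on the report_to mapping: the checkbox is collapsed to the strings "all" or "none", matching the convention of transformers' TrainingArguments.report_to, where "all" enables every installed logging integration (TensorBoard, wandb, ...) and "none" disables external logging. A minimal sketch of where that value ends up (output_dir is a placeholder):

from transformers import TrainingArguments

# "all" turns on every installed logger; "none" disables external logging.
args = TrainingArguments(output_dir="out", report_to="all")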