update webui
This commit is contained in:
parent
816d714146
commit
d0842f6828
|
@ -4,7 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
|||
--stage orpo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--dataset comparison_gpt4_en \
|
||||
--dataset orca_rlhf \
|
||||
--dataset_dir ../../data \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
|
|
|
@ -21,10 +21,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Row():
|
||||
training_stage = gr.Dropdown(
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
|
@ -75,11 +75,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
optim = gr.Textbox(value="adamw_torch")
|
||||
|
||||
with gr.Row():
|
||||
resize_vocab = gr.Checkbox()
|
||||
packing = gr.Checkbox()
|
||||
upcast_layernorm = gr.Checkbox()
|
||||
use_llama_pro = gr.Checkbox()
|
||||
shift_attn = gr.Checkbox()
|
||||
with gr.Column():
|
||||
resize_vocab = gr.Checkbox()
|
||||
packing = gr.Checkbox()
|
||||
|
||||
with gr.Column():
|
||||
upcast_layernorm = gr.Checkbox()
|
||||
use_llama_pro = gr.Checkbox()
|
||||
|
||||
with gr.Column():
|
||||
shift_attn = gr.Checkbox()
|
||||
report_to = gr.Checkbox()
|
||||
|
||||
input_elems.update(
|
||||
{
|
||||
|
@ -93,6 +99,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
upcast_layernorm,
|
||||
use_llama_pro,
|
||||
shift_attn,
|
||||
report_to,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
|
@ -108,6 +115,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
upcast_layernorm=upcast_layernorm,
|
||||
use_llama_pro=use_llama_pro,
|
||||
shift_attn=shift_attn,
|
||||
report_to=report_to,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -536,6 +536,20 @@ LOCALES = {
|
|||
"info": "使用 LongLoRA 提出的 shift short attention。",
|
||||
},
|
||||
},
|
||||
"report_to": {
|
||||
"en": {
|
||||
"label": "Enable external logger",
|
||||
"info": "Use TensorBoard or wandb to log experiment.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Включить внешний регистратор",
|
||||
"info": "Использовать TensorBoard или wandb для ведения журнала экспериментов.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "启用外部记录面板",
|
||||
"info": "使用 TensorBoard 或 wandb 记录实验。",
|
||||
},
|
||||
},
|
||||
"freeze_tab": {
|
||||
"en": {
|
||||
"label": "Freeze tuning configurations",
|
||||
|
|
|
@ -80,20 +80,18 @@ class Runner:
|
|||
if not from_preview and not is_torch_cuda_available():
|
||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||
|
||||
self.aborted = False
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||
self.thread = None
|
||||
self.running_data = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.running_data = None
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
return ALERTS["info_aborted"][lang]
|
||||
else:
|
||||
return finish_info
|
||||
return finish_info
|
||||
|
||||
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
|
@ -141,6 +139,7 @@ class Runner:
|
|||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
use_llama_pro=get("train.use_llama_pro"),
|
||||
shift_attn=get("train.shift_attn"),
|
||||
report_to="all" if get("train.report_to") else "none",
|
||||
use_galore=get("train.use_galore"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||
fp16=(get("train.compute_type") == "fp16"),
|
||||
|
|
Loading…
Reference in New Issue