improve web ui

hiyouga 2024-01-10 12:37:45 +08:00
parent 05ed4e8028
commit 1653c22438
4 changed files with 26 additions and 14 deletions

View File

@@ -79,6 +79,10 @@ def get_current_device() -> torch.device:
return torch.device(device)
+def get_device_count() -> int:
+return torch.cuda.device_count()
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.

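Note: the new helper is a thin wrapper around torch.cuda.device_count(), which reports 0 on a CPU-only host, 1 on a single-GPU box, and so on. A minimal sketch of how the rest of this commit uses it; the final print is illustrative only and not part of the change:

import torch

def get_device_count() -> int:
    # number of CUDA devices visible to this process (0 when no GPU is available)
    return torch.cuda.device_count()

if get_device_count() > 1:
    # the web UI now refuses to start training on multi-GPU hosts and shows an alert instead
    print("multiple GPUs visible, web UI training is blocked")
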
View File

@@ -37,7 +37,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
learning_rate = gr.Textbox(value="5e-5")
num_train_epochs = gr.Textbox(value="3.0")
max_samples = gr.Textbox(value="100000")
-compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
elem_dict.update(dict(
@@ -68,13 +68,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
with gr.Column():
-train_on_prompt = gr.Checkbox(value=False)
+sft_packing = gr.Checkbox(value=False)
upcast_layernorm = gr.Checkbox(value=False)
-input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, train_on_prompt, upcast_layernorm})
+input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
elem_dict.update(dict(
extra_tab=extra_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
-neftune_alpha=neftune_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
+neftune_alpha=neftune_alpha, sft_packing=sft_packing, upcast_layernorm=upcast_layernorm
))
with gr.Accordion(label="LoRA config", open=False) as lora_tab:

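Note: in UI terms the edits above amount to a wider precision radio plus a packing checkbox that replaces the train_on_prompt toggle. A small self-contained sketch; the gr.Blocks scaffolding, labels, and launch call are added here purely for illustration:

import gradio as gr

with gr.Blocks() as demo:
    # "fp32" joins the two mixed-precision choices; it enables neither fp16 nor bf16 downstream
    compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16", label="Compute type")
    # sequence packing toggle that takes the place of the old train_on_prompt checkbox
    sft_packing = gr.Checkbox(value=False, label="Pack sequences")

demo.launch()  # serves the two controls on a local Gradio page
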
View File

@@ -325,14 +325,14 @@ LOCALES = {
"info": "嵌入向量所添加的噪声大小。"
}
},
-"train_on_prompt": {
+"sft_packing": {
"en": {
-"label": "Train on prompt",
-"info": "Compute loss on the prompt tokens in supervised fine-tuning."
+"label": "Pack sequences",
+"info": "Pack sequences into samples of fixed length in supervised fine-tuning."
},
"zh": {
-"label": "计算输入损失",
-"info": "监督微调时候计算输入序列的损失"
+"label": "序列打包",
+"info": "有监督微调阶段将序列打包为相同长度的样本"
}
},
"upcast_layernorm": {
@@ -342,7 +342,7 @@ LOCALES = {
},
"zh": {
"label": "缩放归一化层",
-"info": "将归一化层权重缩放至 32 位浮点数"
+"info": "将归一化层权重缩放至 32 位精度"
}
},
"lora_tab": {
@@ -665,6 +665,10 @@ ALERTS = {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。"
},
+"err_device_count": {
+"en": "Multiple GPUs are not supported yet.",
+"zh": "尚不支持多 GPU 训练。"
+},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

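Note: the locale tables are nested dicts keyed by message id and then by language code, which is what the ALERTS[...][lang] lookup in the runner relies on. A trimmed-down sketch using only the entry added above; the hard-coded language value is an assumption:

ALERTS = {
    "err_device_count": {
        "en": "Multiple GPUs are not supported yet.",
        "zh": "尚不支持多 GPU 训练。",
    },
}

lang = "en"  # stand-in for the language the runner passes in
print(ALERTS["err_device_count"][lang])  # Multiple GPUs are not supported yet.
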
View File

@@ -12,7 +12,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
-from llmtuner.extras.misc import torch_gc
+from llmtuner.extras.misc import get_device_count, torch_gc
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
@@ -67,6 +67,9 @@ class Runner:
if self.demo_mode and (not from_preview):
return ALERTS["err_demo"][lang]
+if not from_preview and get_device_count() > 1:
+return ALERTS["err_device_count"][lang]
self.aborted = False
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
@@ -119,16 +122,17 @@
save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"),
neftune_noise_alpha=get("train.neftune_alpha") or None,
-train_on_prompt=get("train.train_on_prompt"),
+sft_packing=get("train.sft_packing"),
upcast_layernorm=get("train.upcast_layernorm"),
lora_rank=get("train.lora_rank"),
lora_dropout=get("train.lora_dropout"),
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
additional_target=get("train.additional_target") or None,
create_new_adapter=get("train.create_new_adapter"),
-output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
+output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
+fp16=(get("train.compute_type") == "fp16"),
+bf16=(get("train.compute_type") == "bf16")
)
-args[get("train.compute_type")] = True
args["disable_tqdm"] = True
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
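
Note: with "fp32" now selectable, the runner can no longer set args[compute_type] = True blindly, since Transformers' TrainingArguments has fp16 and bf16 flags but no fp32 flag; hence the switch to explicit booleans above. A tiny sketch of the resulting mapping, with the sample value chosen as an assumption:

compute_type = "fp32"  # stands in for get("train.compute_type")

precision_flags = dict(
    fp16=(compute_type == "fp16"),
    bf16=(compute_type == "bf16"),
)
# fp32 leaves both flags False, i.e. full-precision training;
# fp16 or bf16 enables exactly one of the two mixed-precision modes.
print(precision_flags)  # {'fp16': False, 'bf16': False}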