update webui #1086

This commit is contained in:
hiyouga 2023-10-09 14:50:14 +08:00
parent a683c5b797
commit b8dbec086e
10 changed files with 105 additions and 56 deletions

View File

@ -1,5 +1,5 @@
torch>=1.13.1
transformers>=4.30.0
transformers>=4.31.0
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0

View File

@ -39,7 +39,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
check_min_version("4.30.0")
check_min_version("4.31.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")

View File

@ -26,7 +26,10 @@ class WebChatModel(ChatModel):
finetuning_type: str,
quantization_bit: str,
template: str,
system_prompt: str
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
@ -42,9 +45,7 @@ class WebChatModel(ChatModel):
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
else:
checkpoint_dir = None
@ -55,7 +56,10 @@ class WebChatModel(ChatModel):
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
template=template,
system_prompt=system_prompt
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
)
super().__init__(args)

View File

@ -3,8 +3,14 @@ import os
from typing import Any, Dict, Optional
import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
)
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
@ -14,6 +20,14 @@ DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
CKPT_NAMES = [
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
def get_save_dir(*args) -> os.PathLike:
@ -61,10 +75,7 @@ def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
for checkpoint in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([
os.path.isfile(os.path.join(save_dir, checkpoint, name))
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
])
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
):
checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints)
@ -75,6 +86,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f)
except:
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
return {}

View File

@ -57,6 +57,9 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"],
dataset_dir,
dataset,
cutoff_len,

View File

@ -28,7 +28,10 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"]
top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"]
],
[info_box]
).then(

View File

@ -26,10 +26,16 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
system_prompt = gr.Textbox(scale=2)
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
with gr.Row():
flash_attn = gr.Checkbox(value=False)
shift_attn = gr.Checkbox(value=False)
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
lang.change(save_config, [lang, model_name, model_path])
model_name.change(
@ -62,5 +68,9 @@ def create_top() -> Dict[str, "Component"]:
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
system_prompt=system_prompt
system_prompt=system_prompt,
llama_tab=llama_tab,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling
)

View File

@ -55,8 +55,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
flash_attn = gr.Checkbox(value=False)
rope_scaling = gr.Checkbox(value=False)
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
@ -67,8 +65,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
reward_model = gr.Dropdown(scale=2)
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=3)
refresh_btn = gr.Button(scale=1)
refresh_btn.click(
@ -105,6 +103,9 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
top_elems["quantization_bit"],
top_elems["template"],
top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"],
training_stage,
dataset_dir,
dataset,
@ -121,8 +122,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
logging_steps,
save_steps,
warmup_steps,
flash_attn,
rope_scaling,
lora_rank,
lora_dropout,
lora_target,
@ -167,8 +166,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
flash_attn=flash_attn,
rope_scaling=rope_scaling,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,

View File

@ -59,12 +59,12 @@ LOCALES = {
},
"quantization_bit": {
"en": {
"label": "Quantization bit (optional)",
"info": "Enable 4/8-bit model quantization."
"label": "Quantization bit",
"info": "Enable 4/8-bit model quantization (QLoRA)."
},
"zh": {
"label": "量化等级(非必填)",
"info": "启用 4/8 比特模型量化"
"label": "量化等级",
"info": "启用 4/8 比特模型量化QLoRA"
}
},
"template": {
@ -87,6 +87,38 @@ LOCALES = {
"info": "默认使用的系统提示词"
}
},
"llama_tab": {
"en": {
"label": "Model configurations (LLaMA only)"
},
"zh": {
"label": "模型设置仅LLaMA"
}
},
"flash_attn": {
"en": {
"label": "Use FlashAttention-2"
},
"zh": {
"label": "使用 FlashAttention-2"
}
},
"shift_attn": {
"en": {
"label": "Use shift short attention (S^2-Attn)"
},
"zh": {
"label": "使用 shift short attention (S^2-Attn)"
}
},
"rope_scaling": {
"en": {
"label": "RoPE scaling"
},
"zh": {
"label": "RoPE 插值方法"
}
},
"training_stage": {
"en": {
"label": "Stage",
@ -277,22 +309,6 @@ LOCALES = {
"info": "学习率预热采用的步数。"
}
},
"flash_attn": {
"en": {
"label": "Use FlashAttention-2"
},
"zh": {
"label": "使用 FlashAttention-2"
}
},
"rope_scaling": {
"en": {
"label": "Use RoPE scaling"
},
"zh": {
"label": "使用 RoPE 插值"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
@ -362,11 +378,11 @@ LOCALES = {
"reward_model": {
"en": {
"label": "Reward model",
"info": "Checkpoint of the reward model for PPO training."
"info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中奖励模型的断点路径。"
"info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
}
},
"cmd_preview_btn": {

View File

@ -70,6 +70,9 @@ class Runner:
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
training_stage: str,
dataset_dir: str,
dataset: List[str],
@ -86,8 +89,6 @@ class Runner:
logging_steps: int,
save_steps: int,
warmup_steps: int,
flash_attn: bool,
rope_scaling: bool,
lora_rank: int,
lora_dropout: float,
lora_target: str,
@ -97,9 +98,7 @@ class Runner:
output_dir: str
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
else:
checkpoint_dir = None
@ -119,6 +118,9 @@ class Runner:
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
cutoff_len=cutoff_len,
@ -132,8 +134,6 @@ class Runner:
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
flash_attn=flash_attn,
rope_scaling="linear" if rope_scaling else None,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@ -168,6 +168,9 @@ class Runner:
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
dataset_dir: str,
dataset: List[str],
cutoff_len: int,
@ -179,9 +182,7 @@ class Runner:
temperature: float
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
else:
checkpoint_dir = None
@ -202,6 +203,9 @@ class Runner:
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
cutoff_len=cutoff_len,