update webui #1086
This commit is contained in:
parent
a683c5b797
commit
b8dbec086e
|
@ -1,5 +1,5 @@
|
|||
torch>=1.13.1
|
||||
transformers>=4.30.0
|
||||
transformers>=4.31.0
|
||||
datasets>=2.12.0
|
||||
accelerate>=0.21.0
|
||||
peft>=0.4.0
|
||||
|
|
|
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_min_version("4.30.0")
|
||||
check_min_version("4.31.0")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
|
|
|
@ -26,7 +26,10 @@ class WebChatModel(ChatModel):
|
|||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
system_prompt: str
|
||||
system_prompt: str,
|
||||
flash_attn: bool,
|
||||
shift_attn: bool,
|
||||
rope_scaling: str
|
||||
):
|
||||
if self.model is not None:
|
||||
yield ALERTS["err_exists"][lang]
|
||||
|
@ -42,9 +45,7 @@ class WebChatModel(ChatModel):
|
|||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
|
@ -55,7 +56,10 @@ class WebChatModel(ChatModel):
|
|||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt
|
||||
system_prompt=system_prompt,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
|
|
|
@ -3,8 +3,14 @@ import os
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
)
|
||||
|
||||
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
||||
|
||||
|
@ -14,6 +20,14 @@ DEFAULT_DATA_DIR = "data"
|
|||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
CKPT_NAMES = [
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
]
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
|
@ -61,10 +75,7 @@ def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
|||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([
|
||||
os.path.isfile(os.path.join(save_dir, checkpoint, name))
|
||||
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
|
||||
])
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
|
@ -75,6 +86,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
|||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
|
||||
return {}
|
||||
|
||||
|
||||
|
|
|
@ -57,6 +57,9 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"],
|
||||
top_elems["flash_attn"],
|
||||
top_elems["shift_attn"],
|
||||
top_elems["rope_scaling"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
cutoff_len,
|
||||
|
|
|
@ -28,7 +28,10 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
|
|||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"]
|
||||
top_elems["system_prompt"],
|
||||
top_elems["flash_attn"],
|
||||
top_elems["shift_attn"],
|
||||
top_elems["rope_scaling"]
|
||||
],
|
||||
[info_box]
|
||||
).then(
|
||||
|
|
|
@ -26,10 +26,16 @@ def create_top() -> Dict[str, "Component"]:
|
|||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
|
||||
system_prompt = gr.Textbox(scale=2)
|
||||
|
||||
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
|
||||
with gr.Row():
|
||||
flash_attn = gr.Checkbox(value=False)
|
||||
shift_attn = gr.Checkbox(value=False)
|
||||
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
lang.change(save_config, [lang, model_name, model_path])
|
||||
|
||||
model_name.change(
|
||||
|
@ -62,5 +68,9 @@ def create_top() -> Dict[str, "Component"]:
|
|||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
system_prompt=system_prompt
|
||||
system_prompt=system_prompt,
|
||||
llama_tab=llama_tab,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling
|
||||
)
|
||||
|
|
|
@ -55,8 +55,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
flash_attn = gr.Checkbox(value=False)
|
||||
rope_scaling = gr.Checkbox(value=False)
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
|
@ -67,8 +65,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
|
||||
reward_model = gr.Dropdown(scale=2)
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=3)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
refresh_btn.click(
|
||||
|
@ -105,6 +103,9 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"],
|
||||
top_elems["flash_attn"],
|
||||
top_elems["shift_attn"],
|
||||
top_elems["rope_scaling"],
|
||||
training_stage,
|
||||
dataset_dir,
|
||||
dataset,
|
||||
|
@ -121,8 +122,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
flash_attn,
|
||||
rope_scaling,
|
||||
lora_rank,
|
||||
lora_dropout,
|
||||
lora_target,
|
||||
|
@ -167,8 +166,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
|||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
flash_attn=flash_attn,
|
||||
rope_scaling=rope_scaling,
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
|
|
|
@ -59,12 +59,12 @@ LOCALES = {
|
|||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit (optional)",
|
||||
"info": "Enable 4/8-bit model quantization."
|
||||
"label": "Quantization bit",
|
||||
"info": "Enable 4/8-bit model quantization (QLoRA)."
|
||||
},
|
||||
"zh": {
|
||||
"label": "量化等级(非必填)",
|
||||
"info": "启用 4/8 比特模型量化。"
|
||||
"label": "量化等级",
|
||||
"info": "启用 4/8 比特模型量化(QLoRA)。"
|
||||
}
|
||||
},
|
||||
"template": {
|
||||
|
@ -87,6 +87,38 @@ LOCALES = {
|
|||
"info": "默认使用的系统提示词"
|
||||
}
|
||||
},
|
||||
"llama_tab": {
|
||||
"en": {
|
||||
"label": "Model configurations (LLaMA only)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型设置(仅LLaMA)"
|
||||
}
|
||||
},
|
||||
"flash_attn": {
|
||||
"en": {
|
||||
"label": "Use FlashAttention-2"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 FlashAttention-2"
|
||||
}
|
||||
},
|
||||
"shift_attn": {
|
||||
"en": {
|
||||
"label": "Use shift short attention (S^2-Attn)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 shift short attention (S^2-Attn)"
|
||||
}
|
||||
},
|
||||
"rope_scaling": {
|
||||
"en": {
|
||||
"label": "RoPE scaling"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RoPE 插值方法"
|
||||
}
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
|
@ -277,22 +309,6 @@ LOCALES = {
|
|||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
},
|
||||
"flash_attn": {
|
||||
"en": {
|
||||
"label": "Use FlashAttention-2"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 FlashAttention-2"
|
||||
}
|
||||
},
|
||||
"rope_scaling": {
|
||||
"en": {
|
||||
"label": "Use RoPE scaling"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 RoPE 插值"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations"
|
||||
|
@ -362,11 +378,11 @@ LOCALES = {
|
|||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Checkpoint of the reward model for PPO training."
|
||||
"info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中奖励模型的断点路径。"
|
||||
"info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
|
||||
}
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
|
|
|
@ -70,6 +70,9 @@ class Runner:
|
|||
quantization_bit: str,
|
||||
template: str,
|
||||
system_prompt: str,
|
||||
flash_attn: bool,
|
||||
shift_attn: bool,
|
||||
rope_scaling: str,
|
||||
training_stage: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
|
@ -86,8 +89,6 @@ class Runner:
|
|||
logging_steps: int,
|
||||
save_steps: int,
|
||||
warmup_steps: int,
|
||||
flash_attn: bool,
|
||||
rope_scaling: bool,
|
||||
lora_rank: int,
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
|
@ -97,9 +98,7 @@ class Runner:
|
|||
output_dir: str
|
||||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
|
@ -119,6 +118,9 @@ class Runner:
|
|||
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
cutoff_len=cutoff_len,
|
||||
|
@ -132,8 +134,6 @@ class Runner:
|
|||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
flash_attn=flash_attn,
|
||||
rope_scaling="linear" if rope_scaling else None,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
|
@ -168,6 +168,9 @@ class Runner:
|
|||
quantization_bit: str,
|
||||
template: str,
|
||||
system_prompt: str,
|
||||
flash_attn: bool,
|
||||
shift_attn: bool,
|
||||
rope_scaling: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
cutoff_len: int,
|
||||
|
@ -179,9 +182,7 @@ class Runner:
|
|||
temperature: float
|
||||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
|
||||
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
@ -202,6 +203,9 @@ class Runner:
|
|||
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
cutoff_len=cutoff_len,
|
||||
|
|
Loading…
Reference in New Issue