update webui #1086

This commit is contained in:
hiyouga 2023-10-09 14:50:14 +08:00
parent a683c5b797
commit b8dbec086e
10 changed files with 105 additions and 56 deletions

View File

@ -1,5 +1,5 @@
torch>=1.13.1 torch>=1.13.1
transformers>=4.30.0 transformers>=4.31.0
datasets>=2.12.0 datasets>=2.12.0
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.4.0 peft>=0.4.0

View File

@ -39,7 +39,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
check_min_version("4.30.0") check_min_version("4.31.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")

View File

@ -26,7 +26,10 @@ class WebChatModel(ChatModel):
finetuning_type: str, finetuning_type: str,
quantization_bit: str, quantization_bit: str,
template: str, template: str,
system_prompt: str system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str
): ):
if self.model is not None: if self.model is not None:
yield ALERTS["err_exists"][lang] yield ALERTS["err_exists"][lang]
@ -42,9 +45,7 @@ class WebChatModel(ChatModel):
return return
if checkpoints: if checkpoints:
checkpoint_dir = ",".join( checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
else: else:
checkpoint_dir = None checkpoint_dir = None
@ -55,7 +56,10 @@ class WebChatModel(ChatModel):
finetuning_type=finetuning_type, finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None, quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
template=template, template=template,
system_prompt=system_prompt system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
) )
super().__init__(args) super().__init__(args)

View File

@ -3,8 +3,14 @@ import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import gradio as gr import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.utils import (
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
)
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
@ -14,6 +20,14 @@ DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves" DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config" USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json" DATA_CONFIG = "dataset_info.json"
CKPT_NAMES = [
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
def get_save_dir(*args) -> os.PathLike: def get_save_dir(*args) -> os.PathLike:
@ -61,10 +75,7 @@ def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
for checkpoint in os.listdir(save_dir): for checkpoint in os.listdir(save_dir):
if ( if (
os.path.isdir(os.path.join(save_dir, checkpoint)) os.path.isdir(os.path.join(save_dir, checkpoint))
and any([ and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
os.path.isfile(os.path.join(save_dir, checkpoint, name))
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
])
): ):
checkpoints.append(checkpoint) checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints) return gr.update(value=[], choices=checkpoints)
@ -75,6 +86,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
except: except:
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
return {} return {}

View File

@ -57,6 +57,9 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
top_elems["quantization_bit"], top_elems["quantization_bit"],
top_elems["template"], top_elems["template"],
top_elems["system_prompt"], top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"],
dataset_dir, dataset_dir,
dataset, dataset,
cutoff_len, cutoff_len,

View File

@ -28,7 +28,10 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
top_elems["finetuning_type"], top_elems["finetuning_type"],
top_elems["quantization_bit"], top_elems["quantization_bit"],
top_elems["template"], top_elems["template"],
top_elems["system_prompt"] top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"]
], ],
[info_box] [info_box]
).then( ).then(

View File

@ -26,10 +26,16 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(label="Advanced config", open=False) as advanced_tab: with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row(): with gr.Row():
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1) quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1) template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
system_prompt = gr.Textbox(scale=2) system_prompt = gr.Textbox(scale=2)
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
with gr.Row():
flash_attn = gr.Checkbox(value=False)
shift_attn = gr.Checkbox(value=False)
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
lang.change(save_config, [lang, model_name, model_path]) lang.change(save_config, [lang, model_name, model_path])
model_name.change( model_name.change(
@ -62,5 +68,9 @@ def create_top() -> Dict[str, "Component"]:
advanced_tab=advanced_tab, advanced_tab=advanced_tab,
quantization_bit=quantization_bit, quantization_bit=quantization_bit,
template=template, template=template,
system_prompt=system_prompt system_prompt=system_prompt,
llama_tab=llama_tab,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling
) )

View File

@ -55,8 +55,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
flash_attn = gr.Checkbox(value=False)
rope_scaling = gr.Checkbox(value=False)
with gr.Accordion(label="LoRA config", open=False) as lora_tab: with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row(): with gr.Row():
@ -67,8 +65,8 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row(): with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2) dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=2) reward_model = gr.Dropdown(scale=3)
refresh_btn = gr.Button(scale=1) refresh_btn = gr.Button(scale=1)
refresh_btn.click( refresh_btn.click(
@ -105,6 +103,9 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
top_elems["quantization_bit"], top_elems["quantization_bit"],
top_elems["template"], top_elems["template"],
top_elems["system_prompt"], top_elems["system_prompt"],
top_elems["flash_attn"],
top_elems["shift_attn"],
top_elems["rope_scaling"],
training_stage, training_stage,
dataset_dir, dataset_dir,
dataset, dataset,
@ -121,8 +122,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
logging_steps, logging_steps,
save_steps, save_steps,
warmup_steps, warmup_steps,
flash_attn,
rope_scaling,
lora_rank, lora_rank,
lora_dropout, lora_dropout,
lora_target, lora_target,
@ -167,8 +166,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
logging_steps=logging_steps, logging_steps=logging_steps,
save_steps=save_steps, save_steps=save_steps,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
flash_attn=flash_attn,
rope_scaling=rope_scaling,
lora_tab=lora_tab, lora_tab=lora_tab,
lora_rank=lora_rank, lora_rank=lora_rank,
lora_dropout=lora_dropout, lora_dropout=lora_dropout,

View File

@ -59,12 +59,12 @@ LOCALES = {
}, },
"quantization_bit": { "quantization_bit": {
"en": { "en": {
"label": "Quantization bit (optional)", "label": "Quantization bit",
"info": "Enable 4/8-bit model quantization." "info": "Enable 4/8-bit model quantization (QLoRA)."
}, },
"zh": { "zh": {
"label": "量化等级(非必填)", "label": "量化等级",
"info": "启用 4/8 比特模型量化" "info": "启用 4/8 比特模型量化QLoRA"
} }
}, },
"template": { "template": {
@ -87,6 +87,38 @@ LOCALES = {
"info": "默认使用的系统提示词" "info": "默认使用的系统提示词"
} }
}, },
"llama_tab": {
"en": {
"label": "Model configurations (LLaMA only)"
},
"zh": {
"label": "模型设置仅LLaMA"
}
},
"flash_attn": {
"en": {
"label": "Use FlashAttention-2"
},
"zh": {
"label": "使用 FlashAttention-2"
}
},
"shift_attn": {
"en": {
"label": "Use shift short attention (S^2-Attn)"
},
"zh": {
"label": "使用 shift short attention (S^2-Attn)"
}
},
"rope_scaling": {
"en": {
"label": "RoPE scaling"
},
"zh": {
"label": "RoPE 插值方法"
}
},
"training_stage": { "training_stage": {
"en": { "en": {
"label": "Stage", "label": "Stage",
@ -277,22 +309,6 @@ LOCALES = {
"info": "学习率预热采用的步数。" "info": "学习率预热采用的步数。"
} }
}, },
"flash_attn": {
"en": {
"label": "Use FlashAttention-2"
},
"zh": {
"label": "使用 FlashAttention-2"
}
},
"rope_scaling": {
"en": {
"label": "Use RoPE scaling"
},
"zh": {
"label": "使用 RoPE 插值"
}
},
"lora_tab": { "lora_tab": {
"en": { "en": {
"label": "LoRA configurations" "label": "LoRA configurations"
@ -362,11 +378,11 @@ LOCALES = {
"reward_model": { "reward_model": {
"en": { "en": {
"label": "Reward model", "label": "Reward model",
"info": "Checkpoint of the reward model for PPO training." "info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
}, },
"zh": { "zh": {
"label": "奖励模型", "label": "奖励模型",
"info": "PPO 训练中奖励模型的断点路径。" "info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
} }
}, },
"cmd_preview_btn": { "cmd_preview_btn": {

View File

@ -70,6 +70,9 @@ class Runner:
quantization_bit: str, quantization_bit: str,
template: str, template: str,
system_prompt: str, system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
training_stage: str, training_stage: str,
dataset_dir: str, dataset_dir: str,
dataset: List[str], dataset: List[str],
@ -86,8 +89,6 @@ class Runner:
logging_steps: int, logging_steps: int,
save_steps: int, save_steps: int,
warmup_steps: int, warmup_steps: int,
flash_attn: bool,
rope_scaling: bool,
lora_rank: int, lora_rank: int,
lora_dropout: float, lora_dropout: float,
lora_target: str, lora_target: str,
@ -97,9 +98,7 @@ class Runner:
output_dir: str output_dir: str
) -> Tuple[str, str, List[str], str, Dict[str, Any]]: ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints: if checkpoints:
checkpoint_dir = ",".join( checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
else: else:
checkpoint_dir = None checkpoint_dir = None
@ -119,6 +118,9 @@ class Runner:
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None, quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template, template=template,
system_prompt=system_prompt, system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir, dataset_dir=dataset_dir,
dataset=",".join(dataset), dataset=",".join(dataset),
cutoff_len=cutoff_len, cutoff_len=cutoff_len,
@ -132,8 +134,6 @@ class Runner:
logging_steps=logging_steps, logging_steps=logging_steps,
save_steps=save_steps, save_steps=save_steps,
warmup_steps=warmup_steps, warmup_steps=warmup_steps,
flash_attn=flash_attn,
rope_scaling="linear" if rope_scaling else None,
lora_rank=lora_rank, lora_rank=lora_rank,
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"), lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@ -168,6 +168,9 @@ class Runner:
quantization_bit: str, quantization_bit: str,
template: str, template: str,
system_prompt: str, system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
dataset_dir: str, dataset_dir: str,
dataset: List[str], dataset: List[str],
cutoff_len: int, cutoff_len: int,
@ -179,9 +182,7 @@ class Runner:
temperature: float temperature: float
) -> Tuple[str, str, List[str], str, Dict[str, Any]]: ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints: if checkpoints:
checkpoint_dir = ",".join( checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints)) output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
else: else:
checkpoint_dir = None checkpoint_dir = None
@ -202,6 +203,9 @@ class Runner:
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None, quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template, template=template,
system_prompt=system_prompt, system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir, dataset_dir=dataset_dir,
dataset=",".join(dataset), dataset=",".join(dataset),
cutoff_len=cutoff_len, cutoff_len=cutoff_len,