Integrate RLHF into the web UI

hiyouga 2023-08-14 10:48:47 +08:00
parent 2f2fd55d81
commit ec94274ca1
11 changed files with 128 additions and 32 deletions

View File

@@ -68,7 +68,7 @@
| ---------------------- | -------------- | ----------------- | ---- | ----- |
| Pre-Training | ✅ | ✅ | ✅ | ✅ |
| Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
| Reward Model Training | | | ✅ | ✅ |
| Reward Modeling | | | ✅ | ✅ |
| PPO Training | | | ✅ | ✅ |
| DPO Training | ✅ | | ✅ | ✅ |
@@ -103,7 +103,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling or DPO training:
- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -206,7 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### Reward Model Training
### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

View File

@@ -37,7 +37,9 @@ def run_ppo(
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_cuda_cache=True
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
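For reference, a full standalone construction of the trl `PPOConfig` with the two newly forwarded options might look like the sketch below; the concrete values are placeholders standing in for the corresponding `training_args` fields.

```python
from trl import PPOConfig

# Minimal sketch with placeholder values; in run_ppo these come from training_args.
ppo_config = PPOConfig(
    model_name="path/to/base_model",
    learning_rate=5e-5,
    batch_size=4 * 4,              # per-device batch size * gradient accumulation steps
    gradient_accumulation_steps=4,
    ppo_epochs=1,
    max_grad_norm=1.0,
    seed=42,                       # newly forwarded from training_args.seed
    optimize_cuda_cache=True,      # newly enabled to release cached memory between PPO steps
)
```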

View File

@@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {"last_model": "", "path_dict": {}}
return {"lang": "", "last_model": "", "path_dict": {}}
def save_config(model_name: str, model_path: str) -> None:
def save_config(lang: str, model_name: str, model_path: str) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
user_config["lang"] = lang or user_config["lang"]
if model_name:
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False)
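A quick illustration of the new behaviour: the UI language is now persisted alongside the last-used model, and a call with an empty model name only touches the language. The import path is assumed from the surrounding module; the model name and path are placeholders.

```python
from llmtuner.webui.common import load_config, save_config

# Persist the chosen UI language together with the last-used model path.
save_config(lang="zh", model_name="LLaMA-7B", model_path="/models/llama-7b")

# With no model selected, only the language is written; the path map stays untouched.
save_config(lang="zh", model_name="", model_path="")

print(load_config())
# {'lang': 'zh', 'last_model': 'LLaMA-7B', 'path_dict': {'LLaMA-7B': '/models/llama-7b'}}
```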

View File

@@ -1,5 +1,5 @@
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab
from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab

View File

@@ -20,22 +20,25 @@ def create_top() -> Dict[str, "Component"]:
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(["", "8", "4"], scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
source_prefix = gr.Textbox(scale=2)
lang.change(save_config, [lang, model_name, model_path])
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
) # no need to save the config here; the model_path.change handler below already saves it
model_path.change(save_config, [model_name, model_path])
model_path.change(save_config, [lang, model_name, model_path])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -43,7 +46,9 @@ def create_top() -> Dict[str, "Component"]:
can_quantize, [finetuning_type], [quantization_bit]
)
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
refresh_btn.click(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
)
return dict(
lang=lang,

View File

@@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from llmtuner.webui.runner import Runner
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -40,7 +40,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
max_grad_norm = gr.Textbox(value="1.0")
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
@@ -60,6 +60,20 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
lora_target = gr.Textbox(scale=2)
resume_lora_training = gr.Checkbox(value=True, scale=1)
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row():
rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", scale=1)
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
reward_model = gr.Dropdown(scale=2)
refresh_btn = gr.Button(scale=1)
refresh_btn.click(
list_checkpoint,
[top_elems["model_name"], top_elems["finetuning_type"]],
[reward_model],
queue=False
)
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button()
@@ -79,7 +93,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_list = [
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
@@ -108,16 +122,19 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
lora_dropout,
lora_target,
resume_lora_training,
rlhf_method,
dpo_beta,
reward_model,
output_dir
]
output_list = [
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(runner.preview_train, input_list, output_list)
start_btn.click(runner.run_train, input_list, output_list)
cmd_preview_btn.click(runner.preview_train, input_components, output_components)
start_btn.click(runner.run_train, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
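Since the RLHF components are simply appended to `input_components`, they reach the runner the same way every other control does: Gradio passes the current component values positionally, in list order, to the callback. A self-contained sketch of that pattern (hypothetical component names and command string, not the project's actual layout):

```python
import gradio as gr

def preview_command(model_name: str, rlhf_method: str, dpo_beta: float) -> str:
    # Values arrive in the same order as the inputs list passed to .click().
    return f"train --model {model_name} --rlhf {rlhf_method} --dpo_beta {dpo_beta}"

with gr.Blocks() as demo:
    model_name = gr.Textbox(label="Model name")
    rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None")
    dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
    output_box = gr.Textbox(label="Command preview")
    preview_btn = gr.Button("Preview")
    preview_btn.click(preview_command, [model_name, rlhf_method, dpo_beta], [output_box])

# demo.launch()  # uncomment to try it locally
```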
process_bar.change(
@@ -152,6 +169,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
rlhf_tab=rlhf_tab,
rlhf_method=rlhf_method,
dpo_beta=dpo_beta,
reward_model=reward_model,
refresh_btn=refresh_btn,
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,

View File

@@ -3,7 +3,7 @@ from transformers.utils.versions import require_version
from llmtuner.webui.components import (
create_top,
create_sft_tab,
create_train_tab,
create_eval_tab,
create_infer_tab,
create_export_tab,
@@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
with gr.Tab("SFT"):
sft_elems = create_sft_tab(top_elems, runner)
with gr.Tab("Train"):
train_elems = create_train_tab(top_elems, runner)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
@@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)
demo.load(
@@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks:
chat_model = WebChatModel(lazy_init=False)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
lang = gr.Dropdown(choices=["en", "zh"], value="")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)

View File

@@ -335,6 +335,44 @@ LOCALES = {
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
}
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations"
},
"zh": {
"label": "RLHF 参数设置"
}
},
"rlhf_method": {
"en": {
"label": "RLHF method",
"info": "The RLHF algorithm to adopt."
},
"zh": {
"label": "RLHF 方法",
"info": "RLHF 阶段使用的算法。"
}
},
"dpo_beta": {
"en": {
"label": "DPO beta",
"info": "Value of the beta parameter in the DPO loss."
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。"
}
},
"reward_model": {
"en": {
"label": "Reward model",
"info": "Checkpoint of the reward model for PPO training."
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中奖励模型的断点路径。"
}
},
"cmd_preview_btn": {
"en": {
"value": "Preview command"

View File

@@ -12,12 +12,18 @@ class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
self.elem_list = elem_list
def gen_refresh(self) -> Dict[str, Any]:
def gen_refresh(self, lang: str) -> Dict[str, Any]:
refresh_dict = {
"dataset": {"choices": list_dataset()["choices"]},
"output_dir": {"value": get_time()}
}
user_config = load_config()
if lang:
refresh_dict["lang"] = {"value": lang}
else:
refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
if user_config["last_model"]:
refresh_dict["model_name"] = {"value": user_config["last_model"]}
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
@@ -26,10 +32,12 @@ class Manager:
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
update_dict = {}
refresh_dict = self.gen_refresh()
refresh_dict = self.gen_refresh(lang)
for elems in self.elem_list:
for name, component in elems.items():
update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
update_dict[component] = gr.update(
**LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
)
return update_dict
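The net effect of the two changes above: an explicit selection in the language dropdown wins, otherwise the value persisted in the user config is used, otherwise English. A reduced sketch of that precedence, using a hypothetical helper name:

```python
def resolve_lang(selected: str, saved: str) -> str:
    # Same precedence as gen_refresh: UI selection > saved user config > "en".
    if selected:
        return selected
    return saved if saved else "en"

assert resolve_lang("zh", "en") == "zh"   # explicit choice wins
assert resolve_lang("", "zh") == "zh"     # fall back to the saved config
assert resolve_lang("", "") == "en"       # final default
```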

View File

@@ -91,6 +91,9 @@ class Runner:
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
rlhf_method: str,
dpo_beta: float,
reward_model: str,
output_dir: str
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
@@ -109,7 +112,7 @@
overwrite_cache=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
@@ -134,6 +137,21 @@
output_dir=output_dir
)
args[compute_type] = True
if rlhf_method == "Reward Modeling":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif rlhf_method == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
args["reward_model"] = reward_model
args["padding_side"] = "left"
val_size = 0
elif rlhf_method == "DPO":
args["stage"] = "dpo"
args["resume_lora_training"] = False
args["dpo_beta"] = dpo_beta
if val_size > 1e-6:
args["val_size"] = val_size
args["evaluation_strategy"] = "steps"
@@ -176,7 +194,7 @@ class Runner:
predict_with_generate=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,

View File

@@ -63,7 +63,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str:
args["plot_loss"] = True
if args.get("do_train", None):
args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
for k, v in args.items():
if v is not None and v != "":
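The truncated hunk above gates `plot_loss` on `do_train` and then skips empty values while building the command preview. As a self-contained illustration of that filtering idea (a hypothetical helper, not the project's exact `gen_cmd`):

```python
from typing import Any, Dict

def format_cmd(args: Dict[str, Any]) -> str:
    # Drop unset values, then render one "--key value" flag per continued line.
    parts = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"]
    parts += [f"--{k} {v}" for k, v in args.items() if v is not None and v != ""]
    return " \\\n    ".join(parts)

print(format_cmd({"stage": "ppo", "reward_model": "saves/rm-checkpoint", "plot_loss": True}))
# CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
#     --stage ppo \
#     --reward_model saves/rm-checkpoint \
#     --plot_loss True
```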