web UI integrating RLHF
This commit is contained in:
parent
2f2fd55d81
commit
ec94274ca1
|
@ -68,7 +68,7 @@
|
|||
| ---------------------- | -------------- | ----------------- | ---- | ----- |
|
||||
| Pre-Training | ✅ | ✅ | ✅ | ✅ |
|
||||
| Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ |
|
||||
| Reward Model Training | | | ✅ | ✅ |
|
||||
| Reward Modeling | | | ✅ | ✅ |
|
||||
| PPO Training | | | ✅ | ✅ |
|
||||
| DPO Training | ✅ | | ✅ | ✅ |
|
||||
|
||||
|
@ -103,7 +103,7 @@
|
|||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- For reward modelling or DPO training:
|
||||
- For reward modeling or DPO training:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
@ -206,7 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
--fp16
|
||||
```
|
||||
|
||||
### Reward Model Training
|
||||
### Reward Modeling
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
|
|
@ -37,7 +37,9 @@ def run_ppo(
|
|||
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
|
||||
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
||||
ppo_epochs=1,
|
||||
max_grad_norm=training_args.max_grad_norm
|
||||
max_grad_norm=training_args.max_grad_norm,
|
||||
seed=training_args.seed,
|
||||
optimize_cuda_cache=True
|
||||
)
|
||||
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
|
|
|
@ -29,12 +29,14 @@ def load_config() -> Dict[str, Any]:
|
|||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
return {"last_model": "", "path_dict": {}}
|
||||
return {"lang": "", "last_model": "", "path_dict": {}}
|
||||
|
||||
|
||||
def save_config(model_name: str, model_path: str) -> None:
|
||||
def save_config(lang: str, model_name: str, model_path: str) -> None:
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
user_config = load_config()
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
if model_name:
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from llmtuner.webui.components.top import create_top
|
||||
from llmtuner.webui.components.sft import create_sft_tab
|
||||
from llmtuner.webui.components.train import create_train_tab
|
||||
from llmtuner.webui.components.eval import create_eval_tab
|
||||
from llmtuner.webui.components.infer import create_infer_tab
|
||||
from llmtuner.webui.components.export import create_export_tab
|
||||
|
|
|
@ -20,22 +20,25 @@ def create_top() -> Dict[str, "Component"]:
|
|||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
|
||||
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(["", "8", "4"], scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
|
||||
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
|
||||
source_prefix = gr.Textbox(scale=2)
|
||||
|
||||
lang.change(save_config, [lang, model_name, model_path])
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path]
|
||||
) # do not save config since the below line will save
|
||||
model_path.change(save_config, [model_name, model_path])
|
||||
|
||||
model_path.change(save_config, [lang, model_name, model_path])
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
|
@ -43,7 +46,9 @@ def create_top() -> Dict[str, "Component"]:
|
|||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
)
|
||||
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
|
||||
refresh_btn.click(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
|
|
|
@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
|
||||
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
from llmtuner.webui.runner import Runner
|
||||
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
|
@ -40,7 +40,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
|
||||
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
|
||||
)
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
|
@ -60,6 +60,20 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
lora_target = gr.Textbox(scale=2)
|
||||
resume_lora_training = gr.Checkbox(value=True, scale=1)
|
||||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", scale=1)
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
|
||||
reward_model = gr.Dropdown(scale=2)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
refresh_btn.click(
|
||||
list_checkpoint,
|
||||
[top_elems["model_name"], top_elems["finetuning_type"]],
|
||||
[reward_model],
|
||||
queue=False
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
|
@ -79,7 +93,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_list = [
|
||||
input_components = [
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
|
@ -108,16 +122,19 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
lora_dropout,
|
||||
lora_target,
|
||||
resume_lora_training,
|
||||
rlhf_method,
|
||||
dpo_beta,
|
||||
reward_model,
|
||||
output_dir
|
||||
]
|
||||
|
||||
output_list = [
|
||||
output_components = [
|
||||
output_box,
|
||||
process_bar
|
||||
]
|
||||
|
||||
cmd_preview_btn.click(runner.preview_train, input_list, output_list)
|
||||
start_btn.click(runner.run_train, input_list, output_list)
|
||||
cmd_preview_btn.click(runner.preview_train, input_components, output_components)
|
||||
start_btn.click(runner.run_train, input_components, output_components)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
process_bar.change(
|
||||
|
@ -152,6 +169,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
|
|||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
resume_lora_training=resume_lora_training,
|
||||
rlhf_tab=rlhf_tab,
|
||||
rlhf_method=rlhf_method,
|
||||
dpo_beta=dpo_beta,
|
||||
reward_model=reward_model,
|
||||
refresh_btn=refresh_btn,
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
|
@ -3,7 +3,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
from llmtuner.webui.components import (
|
||||
create_top,
|
||||
create_sft_tab,
|
||||
create_train_tab,
|
||||
create_eval_tab,
|
||||
create_infer_tab,
|
||||
create_export_tab,
|
||||
|
@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks:
|
|||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
top_elems = create_top()
|
||||
|
||||
with gr.Tab("SFT"):
|
||||
sft_elems = create_sft_tab(top_elems, runner)
|
||||
with gr.Tab("Train"):
|
||||
train_elems = create_train_tab(top_elems, runner)
|
||||
|
||||
with gr.Tab("Evaluate"):
|
||||
eval_elems = create_eval_tab(top_elems, runner)
|
||||
|
@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks:
|
|||
with gr.Tab("Export"):
|
||||
export_elems = create_export_tab(top_elems)
|
||||
|
||||
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
|
||||
elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
|
||||
manager = Manager(elem_list)
|
||||
|
||||
demo.load(
|
||||
|
@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks:
|
|||
chat_model = WebChatModel(lazy_init=False)
|
||||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en")
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="")
|
||||
|
||||
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
|
||||
|
||||
|
|
|
@ -335,6 +335,44 @@ LOCALES = {
|
|||
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
|
||||
}
|
||||
},
|
||||
"rlhf_tab": {
|
||||
"en": {
|
||||
"label": "RLHF configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RLHF 参数设置"
|
||||
}
|
||||
},
|
||||
"rlhf_method": {
|
||||
"en": {
|
||||
"label": "RLHF method",
|
||||
"info": "The RLHF algorithm to adopt."
|
||||
},
|
||||
"zh": {
|
||||
"label": "RLHF 方法",
|
||||
"info": "RLHF 阶段使用的算法。"
|
||||
}
|
||||
},
|
||||
"dpo_beta": {
|
||||
"en": {
|
||||
"label": "DPO beta",
|
||||
"info": "Value of the beta parameter in the DPO loss."
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO beta 参数",
|
||||
"info": "DPO 损失函数中 beta 超参数大小。"
|
||||
}
|
||||
},
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Checkpoint of the reward model for PPO training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中奖励模型的断点路径。"
|
||||
}
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview command"
|
||||
|
|
|
@ -12,12 +12,18 @@ class Manager:
|
|||
def __init__(self, elem_list: List[Dict[str, Component]]):
|
||||
self.elem_list = elem_list
|
||||
|
||||
def gen_refresh(self) -> Dict[str, Any]:
|
||||
def gen_refresh(self, lang: str) -> Dict[str, Any]:
|
||||
refresh_dict = {
|
||||
"dataset": {"choices": list_dataset()["choices"]},
|
||||
"output_dir": {"value": get_time()}
|
||||
}
|
||||
|
||||
user_config = load_config()
|
||||
if lang:
|
||||
refresh_dict["lang"] = {"value": lang}
|
||||
else:
|
||||
refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
|
||||
|
||||
if user_config["last_model"]:
|
||||
refresh_dict["model_name"] = {"value": user_config["last_model"]}
|
||||
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
@ -26,10 +32,12 @@ class Manager:
|
|||
|
||||
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
|
||||
update_dict = {}
|
||||
refresh_dict = self.gen_refresh()
|
||||
refresh_dict = self.gen_refresh(lang)
|
||||
|
||||
for elems in self.elem_list:
|
||||
for name, component in elems.items():
|
||||
update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
|
||||
update_dict[component] = gr.update(
|
||||
**LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
|
||||
)
|
||||
|
||||
return update_dict
|
||||
|
|
|
@ -91,6 +91,9 @@ class Runner:
|
|||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
resume_lora_training: bool,
|
||||
rlhf_method: str,
|
||||
dpo_beta: float,
|
||||
reward_model: str,
|
||||
output_dir: str
|
||||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
|
@ -109,7 +112,7 @@ class Runner:
|
|||
overwrite_cache=True,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
|
@ -134,6 +137,21 @@ class Runner:
|
|||
output_dir=output_dir
|
||||
)
|
||||
args[compute_type] = True
|
||||
|
||||
if rlhf_method == "Reward Modeling":
|
||||
args["stage"] = "rm"
|
||||
args["resume_lora_training"] = False
|
||||
elif rlhf_method == "PPO":
|
||||
args["stage"] = "ppo"
|
||||
args["resume_lora_training"] = False
|
||||
args["reward_model"] = reward_model
|
||||
args["padding_side"] = "left"
|
||||
val_size = 0
|
||||
elif rlhf_method == "DPO":
|
||||
args["stage"] = "dpo"
|
||||
args["resume_lora_training"] = False
|
||||
args["dpo_beta"] = dpo_beta
|
||||
|
||||
if val_size > 1e-6:
|
||||
args["val_size"] = val_size
|
||||
args["evaluation_strategy"] = "steps"
|
||||
|
@ -176,7 +194,7 @@ class Runner:
|
|||
predict_with_generate=True,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
|
|
|
@ -63,6 +63,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
|||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
if args.get("do_train", None):
|
||||
args["plot_loss"] = True
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
|
||||
for k, v in args.items():
|
||||
|
|
Loading…
Reference in New Issue