Support ORPO in the web UI

This commit is contained in:
hiyouga 2024-03-31 18:34:59 +08:00
parent 17bf8a2c3a
commit 5195add324
3 changed files with 22 additions and 4 deletions

View File

@@ -169,10 +169,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1) dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2) reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
input_elems.update({dpo_beta, dpo_ftx, reward_model}) input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model)) elem_dict.update(
dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
)
with gr.Accordion(open=False) as galore_tab: with gr.Accordion(open=False) as galore_tab:
with gr.Row(): with gr.Row():

View File

@@ -757,6 +757,20 @@ LOCALES = {
"info": "DPO-ftx 中 SFT 损失的权重大小。", "info": "DPO-ftx 中 SFT 损失的权重大小。",
}, },
}, },
"orpo_beta": {
"en": {
"label": "ORPO beta",
"info": "Value of the beta parameter in the ORPO loss.",
},
"ru": {
"label": "ORPO бета",
"info": "Значение параметра бета в функции потерь ORPO.",
},
"zh": {
"label": "ORPO beta 参数",
"info": "ORPO 损失函数中 beta 超参数大小。",
},
},
"reward_model": { "reward_model": {
"en": { "en": {
"label": "Reward model", "label": "Reward model",

View File

@@ -174,10 +174,11 @@ class Runner:
] ]
) )
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full" args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
elif args["stage"] == "dpo":
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta") args["dpo_beta"] = get("train.dpo_beta")
args["dpo_ftx"] = get("train.dpo_ftx") args["dpo_ftx"] = get("train.dpo_ftx")
elif args["stage"] == "orpo":
args["orpo_beta"] = get("train.orpo_beta")
if get("train.val_size") > 1e-6 and args["stage"] != "ppo": if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size") args["val_size"] = get("train.val_size")